RAG MVP #374

Merged · 7 commits · May 28, 2024
Changes from 3 commits
36 changes: 36 additions & 0 deletions apps/experiments/views/experiment.py
@@ -24,7 +24,10 @@
from django.views.decorators.http import require_POST
from django.views.generic import CreateView, UpdateView
from django_tables2 import SingleTableView
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from langchain_core.prompts import PromptTemplate
from langchain_openai import OpenAIEmbeddings
from waffle import flag_is_active

from apps.annotations.models import Tag
@@ -61,6 +64,7 @@
from apps.teams.decorators import login_and_team_required
from apps.teams.mixins import LoginAndTeamRequiredMixin
from apps.users.models import CustomUser
from apps.vectordb.vectorstore import PGVector


@login_and_team_required
@@ -361,8 +365,40 @@ def form_valid(self, form):
    experiment = get_object_or_404(Experiment, team=self.request.team, pk=self.kwargs["pk"])
    file = super().form_valid(form)
    experiment.files.add(file)
    file_path = file.file.path  # path of the file that was just attached
    splits = self.load_rag_file(file_path)
    embeddings_model = OpenAIEmbeddings()
    PGVector.from_texts(splits, embeddings_model, None, experiment)
    return file

def load_rag_file(self, file_path: str) -> list[str]:
    """
    Loads a file of a supported type (PDF or TXT) into Langchain documents.

    Args:
        file_path (str): The path to the file.

    Returns:
        str_splits: A list of strings taken from the page_content of the
            chunked Langchain Document objects.
    """

    # Pick a loader based on the file extension
    extension = file_path.split(".")[-1].lower()
    if extension == "pdf":
        loader = PyMuPDFLoader(file_path, extract_images=False)
    elif extension in ("txt", "text"):
        loader = TextLoader(file_path)
    else:
        raise ValueError(f"Unsupported file type: {extension}")

    # Load the file, chunk it, and keep only the first 10 chunks (MVP cap)
    documents = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    splits = text_splitter.split_documents(documents)
    str_splits = [s.page_content for s in splits[:10]]
    return str_splits
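Reviewer note on the chunking above: with chunk_size=1000 and chunk_overlap=200, a document yields overlapping chunks of at most roughly 1,000 characters, and the splits[:10] slice means only the first ten chunks are embedded, so larger files are silently truncated. A minimal standalone sketch of that behaviour with the same splitter settings (the sample text is hypothetical):

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# Hypothetical sample: ~6,000 characters of text to chunk.
doc = Document(page_content="lorem ipsum dolor sit amet " * 220)
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents([doc])

# Each chunk stays within the size budget; neighbouring chunks overlap.
assert all(len(c.page_content) <= 1000 for c in chunks)
print(len(chunks), "chunks; only chunks[:10] would be embedded above")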

def get_delete_url(self, file):
    return reverse("experiments:remove_file", args=[self.request.team.slug, self.kwargs["pk"], file.pk])

9 changes: 9 additions & 0 deletions apps/service_providers/llm_service/main.py
@@ -7,6 +7,7 @@
from langchain_community.chat_models import ChatAnthropic
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseLanguageModel
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from openai import OpenAI
from openai._base_client import SyncAPIClient
@@ -42,6 +43,14 @@ class OpenAILlmService(LlmService):
    openai_api_base: str = None
    openai_organization: str = None

    def get_openai_embeddings(self, model="text-embedding-3-small") -> OpenAIEmbeddings:
        return OpenAIEmbeddings(
            openai_api_key=self.openai_api_key,
            openai_api_base=self.openai_api_base,
            openai_organization=self.openai_organization,
            model=model,
        )

    def get_raw_client(self) -> OpenAI:
        return OpenAI(api_key=self.openai_api_key, organization=self.openai_organization, base_url=self.openai_api_base)
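A usage sketch of the new get_openai_embeddings helper for reviewers. The field values below are placeholders; in the app the service is built from the team's stored provider config, and openai_api_key is declared on the class above the visible hunk (it is already used by get_raw_client):

service = OpenAILlmService(
    openai_api_key="sk-...",  # placeholder credential
    openai_api_base=None,
    openai_organization=None,
)
embeddings = service.get_openai_embeddings()  # text-embedding-3-small by default
vector = embeddings.embed_query("hello world")
print(len(vector))  # 1536 dimensions for text-embedding-3-small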

25 changes: 24 additions & 1 deletion apps/service_providers/llm_service/runnables.py
@@ -9,6 +9,7 @@

import openai
import pytz
from langchain import hub
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.agents.openai_assistant.base import OpenAIAssistantFinish
from langchain.memory import ConversationBufferMemory
@@ -29,6 +30,7 @@
from apps.chat.conversation import compress_chat_history
from apps.chat.models import ChatMessage, ChatMessageType
from apps.experiments.models import Experiment, ExperimentSession
from apps.vectordb.vectorstore import PGVector

logger = logging.getLogger(__name__)

@@ -51,7 +53,8 @@ def create_experiment_runnable(experiment: Experiment, session: ExperimentSession
    assert experiment.llm_provider, "Experiment must have an LLM provider"
    if experiment.tools_enabled:
        return AgentExperimentRunnable(experiment=experiment, session=session)
    if experiment.files.all():
        return RagExperimentRunnable(experiment=experiment, session=session)
    return SimpleExperimentRunnable(experiment=experiment, session=session)


@@ -219,6 +222,26 @@ def _build_chain(self) -> Runnable[dict[str, Any], str]:
)


class RagExperimentRunnable(ExperimentRunnable):
    def _build_chain(self) -> Runnable[dict[str, Any], str]:
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        model = self.llm_service.get_chat_model(self.experiment.llm, self.experiment.temperature)
        embeddings = self.experiment.get_llm_service().get_openai_embeddings()
        retriever = PGVector(self.experiment, embeddings).as_retriever()
        prompt = hub.pull("rlm/rag-prompt")
        return (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | RunnablePassthrough.assign(
                history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history")
            )
            | prompt
            | model
            | StrOutputParser()
        )
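A note on the chain above: the published rlm/rag-prompt takes only context and question as input variables, so the history key assigned here likely goes unused unless the prompt is later swapped for one that references it. Roughly, one invocation resolves as in this simplified sketch (illustrative, not project code; retriever, prompt, and model are the objects constructed above):

question = "What does the uploaded file say about pricing?"
docs = retriever.invoke(question)                    # similarity search in pgvector
context = "\n\n".join(d.page_content for d in docs)  # what format_docs produces
messages = prompt.invoke({"context": context, "question": question})
answer = StrOutputParser().invoke(model.invoke(messages))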


class AgentExperimentRunnable(ExperimentRunnable):
    def _parse_output(self, output):
        return output.get("output", "")
2 changes: 2 additions & 0 deletions requirements/requirements.in
@@ -45,3 +45,5 @@ turn-python>=0.2.0
jinja2
django-taggit
pgvector
PyMuPDF
langchainhub
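On the two new pins: PyMuPDF supplies the fitz module that PyMuPDFLoader imports, and langchainhub provides the client behind hub.pull in runnables.py. A quick import check (illustrative):

import fitz                # provided by the PyMuPDF pin
from langchain import hub  # hub.pull requires langchainhub at runtime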