RAG MVP #374

Merged: 7 commits, May 28, 2024
43 changes: 42 additions & 1 deletion apps/experiments/tasks.py
@@ -3,15 +3,18 @@

from celery.app import shared_task
from langchain.schema import AIMessage, HumanMessage
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyMuPDFLoader, TextLoader
from taskbadger.celery import Task as TaskbadgerTask

from apps.channels.datamodels import WebMessage
from apps.chat.bots import create_conversation
from apps.chat.channels import WebChannel
from apps.experiments.models import ExperimentSession, PromptBuilderHistory, SourceMaterial
from apps.experiments.models import Experiment, ExperimentSession, PromptBuilderHistory, SourceMaterial
from apps.service_providers.models import LlmProvider
from apps.users.models import CustomUser
from apps.utils.taskbadger import update_taskbadger_data
from apps.vectordb.vectorstore import PGVector


@shared_task(bind=True, base=TaskbadgerTask)
@@ -23,6 +26,44 @@ def get_response_for_webchat_task(self, experiment_session_id: int, message_text
    return message_handler.new_user_message(message)


@shared_task(bind=True, base=TaskbadgerTask)
def store_rag_embedding(self, experiment_id: int) -> None:
    experiment = Experiment.objects.get(id=experiment_id)
    file_path = experiment.files.all().last().file.path
    splits = load_rag_file(file_path)
    embeddings_model = experiment.get_llm_service().get_openai_embeddings()
    PGVector.from_texts(splits, embeddings_model, None, experiment)


def load_rag_file(file_path: str) -> list[str]:
"""
Loads a text file of any supported type (PDF, TXT, HTML) into Langchain.

Args:
file_path (str): The path to the text file.

Returns:
str_splits: A list of strings from Langchain Document objects
containing the loaded page_content.
"""

# Automatically detect loader based on file extension if not provided
extension = file_path.split(".")[-1].lower()
if extension == "pdf":
loader = PyMuPDFLoader(file_path, extract_images=False)
elif extension in ("txt", "text"):
loader = TextLoader(file_path)
else:
raise ValueError(f"Unsupported file type: {extension}")

# Load the text file using the appropriate loader
documents = loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(documents)
str_splits = [s.page_content for s in splits[0:10]]
return str_splits


@shared_task
def get_prompt_builder_response_task(team_id: int, user_id, data_dict: dict) -> dict[str, str | int]:
    llm_service = LlmProvider.objects.get(id=data_dict["provider"]).get_llm_service()
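A quick way to sanity-check the new loader outside Celery is a standalone sketch of the same load-and-split flow; the file name here is hypothetical, and only APIs already imported by this PR are used:

```python
# Standalone sketch mirroring load_rag_file ("notes.txt" is an assumed local file).
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import TextLoader

documents = TextLoader("notes.txt").load()
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = splitter.split_documents(documents)

# Same cap as the task: only the first ten chunks would be embedded.
str_splits = [doc.page_content for doc in splits[:10]]
print(f"{len(str_splits)} chunks ready for embedding")
```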
7 changes: 6 additions & 1 deletion apps/experiments/views/experiment.py
@@ -53,7 +53,7 @@
from apps.experiments.helpers import get_real_user_or_none
from apps.experiments.models import Experiment, ExperimentSession, Participant, SessionStatus, SyntheticVoice
from apps.experiments.tables import ExperimentSessionsTable, ExperimentTable
from apps.experiments.tasks import get_response_for_webchat_task
from apps.experiments.tasks import get_response_for_webchat_task, store_rag_embedding
from apps.experiments.views.prompt import PROMPT_DATA_SESSION_KEY
from apps.files.forms import get_file_formset
from apps.files.views import BaseAddFileHtmxView, BaseDeleteFileView
@@ -220,6 +220,10 @@ def _validate_prompt_variables(form_data):
    available_variables = set()
    if form_data.get("source_material"):
        available_variables.add("source_material")
    # TODO: these variables should only be added after a DB check confirms
    # that RAG files have been uploaded for the experiment
    available_variables.add("context")
    available_variables.add("input")
    missing_vars = required_variables - available_variables
    known_vars = {"source_material"}
    if missing_vars:
@@ -361,6 +365,7 @@ def form_valid(self, form):
    experiment = get_object_or_404(Experiment, team=self.request.team, pk=self.kwargs["pk"])
    file = super().form_valid(form)
    experiment.files.add(file)
    store_rag_embedding(experiment.id)
    return file

def get_delete_url(self, file):
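Note that `store_rag_embedding` is a Celery `shared_task`, but `form_valid` invokes it directly, so the embedding work runs synchronously inside the request cycle. If background execution is intended, the standard Celery call would be the following sketch, assuming the default Celery app is configured:

```python
# .delay() enqueues the task on the broker instead of running it inline.
store_rag_embedding.delay(experiment.id)
```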
9 changes: 9 additions & 0 deletions apps/service_providers/llm_service/main.py
@@ -7,6 +7,7 @@
from langchain_community.chat_models import ChatAnthropic
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.language_models import BaseLanguageModel
from langchain_openai import OpenAIEmbeddings
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from openai import OpenAI
from openai._base_client import SyncAPIClient
@@ -42,6 +43,14 @@ class OpenAILlmService(LlmService):
    openai_api_base: str = None
    openai_organization: str = None

    def get_openai_embeddings(self, model="text-embedding-3-small") -> OpenAIEmbeddings:
        return OpenAIEmbeddings(
            openai_api_key=self.openai_api_key,
            openai_api_base=self.openai_api_base,
            openai_organization=self.openai_organization,
            model=model,
        )

    def get_raw_client(self) -> OpenAI:
        return OpenAI(api_key=self.openai_api_key, organization=self.openai_organization, base_url=self.openai_api_base)

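A minimal sketch of how the new helper can be exercised, assuming `service` is a configured `OpenAILlmService` instance; `embed_query` is a standard `OpenAIEmbeddings` method:

```python
# Hypothetical usage; `service` stands in for a configured OpenAILlmService.
embeddings = service.get_openai_embeddings()  # defaults to text-embedding-3-small
vector = embeddings.embed_query("What does the uploaded file say about pricing?")
print(len(vector))  # embedding dimensionality (1536 for text-embedding-3-small)
```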
23 changes: 22 additions & 1 deletion apps/service_providers/llm_service/runnables.py
@@ -29,6 +29,7 @@
from apps.chat.conversation import compress_chat_history
from apps.chat.models import ChatMessage, ChatMessageType
from apps.experiments.models import Experiment, ExperimentSession
from apps.vectordb.vectorstore import PGVector

logger = logging.getLogger(__name__)

@@ -51,7 +52,8 @@ def create_experiment_runnable(experiment: Experiment, session: ExperimentSessio
    assert experiment.llm_provider, "Experiment must have an LLM provider"
    if experiment.tools_enabled:
        return AgentExperimentRunnable(experiment=experiment, session=session)

    if experiment.files.exists():
        return RagExperimentRunnable(experiment=experiment, session=session)
    return SimpleExperimentRunnable(experiment=experiment, session=session)


@@ -219,6 +221,25 @@ def _build_chain(self) -> Runnable[dict[str, Any], str]:
)


class RagExperimentRunnable(ExperimentRunnable):
    def _build_chain(self) -> Runnable[dict[str, Any], str]:
        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        model = self.llm_service.get_chat_model(self.experiment.llm, self.experiment.temperature)
        embeddings = self.experiment.get_llm_service().get_openai_embeddings()
        retriever = PGVector(self.experiment, embeddings).as_retriever()
        return (
            {"context": retriever | format_docs, "input": RunnablePassthrough()}
            | RunnablePassthrough.assign(
                history=RunnableLambda(self.memory.load_memory_variables) | itemgetter("history")
            )
            | self.prompt
            | model
            | StrOutputParser()
        )


class AgentExperimentRunnable(ExperimentRunnable):
    def _parse_output(self, output):
        return output.get("output", "")
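The new chain follows the standard LCEL retrieval pattern: the leading dict fans the input out so that `retriever | format_docs` fills `context` while `RunnablePassthrough` forwards the raw question as `input`, then chat history is merged in. A self-contained sketch of the same shape with a stub retriever (the stub documents and prompt template are assumptions; the real chain uses PGVector and the experiment's prompt):

```python
# Minimal LCEL sketch of the RAG fan-out; the stub retriever is an assumption.
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

stub_retriever = RunnableLambda(lambda q: [Document(page_content=f"notes about {q}")])

def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)

prompt = ChatPromptTemplate.from_template("Context:\n{context}\n\nQuestion: {input}")

chain = (
    {"context": stub_retriever | format_docs, "input": RunnablePassthrough()}
    | prompt
    # the real chain continues with `| model | StrOutputParser()`
)
print(chain.invoke("pgvector").to_string())
```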
2 changes: 2 additions & 0 deletions requirements/requirements.in
@@ -45,3 +45,5 @@ turn-python>=0.2.0
jinja2
django-taggit
pgvector
PyMuPDF
langchainhub