-
-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(integration): implementation (#2191)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate):
- Loading branch information
1 parent
6383918
commit ba5ef60
Showing
17 changed files
with
657 additions
and
194 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -79,3 +79,5 @@ paulgraham.py | |
supabase/seed-airwallex.sql | ||
airwallexpayouts.py | ||
application.log | ||
backend/celerybeat-schedule.db | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/bin/bash | ||
|
||
SESSION_NAME="my_services" | ||
|
||
start_services() { | ||
# Create a new tmux session | ||
tmux new-session -d -s $SESSION_NAME | ||
|
||
# Split the window into panes for each service | ||
tmux split-window -h | ||
tmux split-window -v | ||
tmux select-pane -t 0 | ||
tmux split-window -v | ||
|
||
# Start each service in its pane | ||
tmux send-keys -t $SESSION_NAME:0.0 'echo "Starting backend-core...";pipenv run uvicorn main:app --reload --host 0.0.0.0 --port 5050 --workers 6' C-m | ||
tmux send-keys -t $SESSION_NAME:0.1 'echo "Starting worker...";pipenv run celery -A celery_worker worker -l info' C-m | ||
tmux send-keys -t $SESSION_NAME:0.2 'echo "Starting beat...";pipenv run celery -A celery_worker beat -l info' C-m | ||
tmux send-keys -t $SESSION_NAME:0.3 'echo "Starting flower...";pipenv run celery -A celery_worker flower -l info --port=5555' C-m | ||
|
||
echo "Services started in tmux session '$SESSION_NAME'" | ||
echo "Use 'tmux attach-session -t $SESSION_NAME' to view logs" | ||
} | ||
|
||
stop_services() { | ||
# Kill the tmux session | ||
tmux kill-session -t $SESSION_NAME | ||
echo "Services stopped" | ||
} | ||
|
||
view_logs() { | ||
# Attach to the tmux session to view logs | ||
tmux attach-session -t $SESSION_NAME | ||
} | ||
|
||
if [ "$1" == "start" ]; then | ||
start_services | ||
elif [ "$1" == "stop" ]; then | ||
stop_services | ||
elif [ "$1" == "logs" ]; then | ||
view_logs | ||
else | ||
echo "Usage: $0 {start|stop|logs}" | ||
fi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import json | ||
from typing import AsyncIterable | ||
from uuid import UUID | ||
|
||
from langchain_community.chat_models import ChatLiteLLM | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA | ||
from modules.chat.dto.chats import ChatQuestion | ||
|
||
|
||
class GPT4Brain(KnowledgeBrainQA): | ||
"""This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally. | ||
It is going to call the Data Store internally to get the data. | ||
Args: | ||
KnowledgeBrainQA (_type_): A brain that store the knowledge internaly | ||
""" | ||
|
||
def __init__( | ||
self, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
**kwargs, | ||
) | ||
|
||
def get_chain(self): | ||
|
||
prompt = ChatPromptTemplate.from_messages( | ||
[ | ||
("system", "You are GPT-4 powered by Quivr. You are an assistant."), | ||
MessagesPlaceholder(variable_name="chat_history"), | ||
("human", "{question}"), | ||
] | ||
) | ||
|
||
chain = prompt | ChatLiteLLM( | ||
model="gpt-4-0125-preview", max_tokens=self.max_tokens | ||
) | ||
|
||
return chain | ||
|
||
async def generate_stream( | ||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True | ||
) -> AsyncIterable: | ||
conversational_qa_chain = self.get_chain() | ||
transformed_history, streamed_chat_history = ( | ||
self.initialize_streamed_chat_history(chat_id, question) | ||
) | ||
response_tokens = [] | ||
|
||
async for chunk in conversational_qa_chain.astream( | ||
{ | ||
"question": question.question, | ||
"chat_history": transformed_history, | ||
} | ||
): | ||
response_tokens.append(chunk.content) | ||
streamed_chat_history.assistant = chunk.content | ||
yield f"data: {json.dumps(streamed_chat_history.dict())}" | ||
|
||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
import json | ||
from typing import AsyncIterable | ||
from uuid import UUID | ||
|
||
from langchain_community.chat_models import ChatLiteLLM | ||
from langchain_community.utilities import SQLDatabase | ||
from langchain_core.output_parsers import StrOutputParser | ||
from langchain_core.prompts import ChatPromptTemplate | ||
from langchain_core.runnables import RunnablePassthrough | ||
from modules.brain.integrations.SQL.SQL_connector import SQLConnector | ||
from modules.brain.knowledge_brain_qa import KnowledgeBrainQA | ||
from modules.brain.repository.integration_brains import IntegrationBrain | ||
from modules.chat.dto.chats import ChatQuestion | ||
|
||
|
||
class SQLBrain(KnowledgeBrainQA, IntegrationBrain): | ||
"""This is the Notion brain class. it is a KnowledgeBrainQA has the data is stored locally. | ||
It is going to call the Data Store internally to get the data. | ||
Args: | ||
KnowledgeBrainQA (_type_): A brain that store the knowledge internaly | ||
""" | ||
|
||
uri: str = None | ||
db: SQLDatabase = None | ||
sql_connector: SQLConnector = None | ||
|
||
def __init__( | ||
self, | ||
**kwargs, | ||
): | ||
super().__init__( | ||
**kwargs, | ||
) | ||
self.sql_connector = SQLConnector(self.brain_id, self.user_id) | ||
|
||
def get_schema(self, _): | ||
return self.db.get_table_info() | ||
|
||
def run_query(self, query): | ||
return self.db.run(query) | ||
|
||
def get_chain(self): | ||
template = """Based on the table schema below, write a SQL query that would answer the user's question: | ||
{schema} | ||
Question: {question} | ||
SQL Query:""" | ||
prompt = ChatPromptTemplate.from_template(template) | ||
|
||
self.db = SQLDatabase.from_uri(self.sql_connector.credentials["uri"]) | ||
|
||
model = ChatLiteLLM(model=self.model) | ||
|
||
sql_response = ( | ||
RunnablePassthrough.assign(schema=self.get_schema) | ||
| prompt | ||
| model.bind(stop=["\nSQLResult:"]) | ||
| StrOutputParser() | ||
) | ||
|
||
template = """Based on the table schema below, question, sql query, and sql response, write a natural language response and the query that was used to generate it.: | ||
{schema} | ||
Question: {question} | ||
SQL Query: {query} | ||
SQL Response: {response}""" | ||
prompt_response = ChatPromptTemplate.from_template(template) | ||
|
||
full_chain = ( | ||
RunnablePassthrough.assign(query=sql_response).assign( | ||
schema=self.get_schema, | ||
response=lambda x: self.db.run(x["query"]), | ||
) | ||
| prompt_response | ||
| model | ||
) | ||
|
||
return full_chain | ||
|
||
async def generate_stream( | ||
self, chat_id: UUID, question: ChatQuestion, save_answer: bool = True | ||
) -> AsyncIterable: | ||
|
||
conversational_qa_chain = self.get_chain() | ||
transformed_history, streamed_chat_history = ( | ||
self.initialize_streamed_chat_history(chat_id, question) | ||
) | ||
response_tokens = [] | ||
|
||
async for chunk in conversational_qa_chain.astream( | ||
{ | ||
"question": question.question, | ||
} | ||
): | ||
response_tokens.append(chunk.content) | ||
streamed_chat_history.assistant = chunk.content | ||
yield f"data: {json.dumps(streamed_chat_history.dict())}" | ||
|
||
self.save_answer(question, response_tokens, streamed_chat_history, save_answer) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from logger import get_logger | ||
from modules.brain.entity.integration_brain import IntegrationEntity | ||
from modules.brain.repository.integration_brains import IntegrationBrain | ||
from modules.knowledge.repository.knowledge_interface import KnowledgeInterface | ||
from modules.knowledge.service.knowledge_service import KnowledgeService | ||
|
||
logger = get_logger(__name__) | ||
|
||
|
||
class SQLConnector(IntegrationBrain): | ||
"""A class to interact with an SQL database""" | ||
|
||
credentials: dict[str, str] = None | ||
integration_details: IntegrationEntity = None | ||
brain_id: str = None | ||
user_id: str = None | ||
knowledge_service: KnowledgeInterface | ||
|
||
def __init__(self, brain_id: str, user_id: str): | ||
super().__init__() | ||
self.brain_id = brain_id | ||
self.user_id = user_id | ||
self._load_credentials() | ||
self.knowledge_service = KnowledgeService() | ||
|
||
def _load_credentials(self) -> dict[str, str]: | ||
"""Load the Notion credentials""" | ||
self.integration_details = self.get_integration_brain( | ||
self.brain_id, self.user_id | ||
) | ||
if self.credentials is None: | ||
logger.info("Loading Notion credentials") | ||
self.integration_details.credentials = { | ||
"uri": self.integration_details.settings.get("uri", "") | ||
} | ||
self.update_integration_brain( | ||
self.brain_id, self.user_id, self.integration_details | ||
) | ||
self.credentials = self.integration_details.credentials | ||
else: # pragma: no cover | ||
self.credentials = self.integration_details.credentials |
Empty file.
Oops, something went wrong.