From 9464707d4058c74fe8d03c57e04c4d416b98ed48 Mon Sep 17 00:00:00 2001 From: Mamadou DICKO <63923024+mamadoudicko@users.noreply.github.com> Date: Thu, 7 Sep 2023 17:23:31 +0200 Subject: [PATCH] feat: merge chat history with chat notifications (#1127) * feat: add chat_id to upload and crawl payload * feat(chat): return chat_history_with_notifications * feat: explicit notification status on create * feat: handle notifications in frontend * feat: delete chat notifications on chat delete request --- backend/models/databases/repository.py | 4 ++ .../databases/supabase/notifications.py | 17 +++++- .../get_chat_history_with_notifications.py | 57 +++++++++++++++++++ .../notification/get_chat_notifications.py | 14 +++++ .../notification/remove_chat_notifications.py | 11 ++++ backend/routes/chat_routes.py | 14 ++++- backend/routes/crawl_routes.py | 26 +++++---- backend/routes/upload_routes.py | 27 +++++---- backend/utils/parse_message_time.py | 5 ++ .../ActionsBar/hooks/useKnowledgeUploader.ts | 20 +++++-- .../[chatId]/hooks/useSelectedChatPage.ts | 4 +- frontend/app/chat/[chatId]/types/index.ts | 16 ++++++ .../utils/getMessagesFromChatHistory.ts | 11 ++++ frontend/lib/api/chat/chat.ts | 5 +- .../api/crawl/__tests__/useCrawlApi.test.ts | 5 +- frontend/lib/api/crawl/crawl.ts | 6 +- frontend/lib/api/upload/upload.ts | 6 +- 17 files changed, 213 insertions(+), 35 deletions(-) create mode 100644 backend/repository/chat/get_chat_history_with_notifications.py create mode 100644 backend/repository/notification/get_chat_notifications.py create mode 100644 backend/repository/notification/remove_chat_notifications.py create mode 100644 backend/utils/parse_message_time.py create mode 100644 frontend/app/chat/[chatId]/utils/getMessagesFromChatHistory.ts diff --git a/backend/models/databases/repository.py b/backend/models/databases/repository.py index 6c6d81dd6e66..43266fe2d177 100644 --- a/backend/models/databases/repository.py +++ b/backend/models/databases/repository.py @@ -232,6 +232,10 @@ def update_notification_by_id(self, id: UUID): def remove_notification_by_id(self, id: UUID): pass + @abstractmethod + def remove_notifications_by_chat_id(self, chat_id: UUID): + pass + @abstractmethod def get_notifications_by_chat_id(self, chat_id: UUID): pass diff --git a/backend/models/databases/supabase/notifications.py b/backend/models/databases/supabase/notifications.py index 9b6d47658178..2d6e96d2977b 100644 --- a/backend/models/databases/supabase/notifications.py +++ b/backend/models/databases/supabase/notifications.py @@ -93,6 +93,19 @@ def remove_notification_by_id( status="deleted", notification_id=notification_id ) + def remove_notifications_by_chat_id(self, chat_id: UUID) -> None: + """ + Remove all notifications for a chat + Args: + chat_id (UUID): The id of the chat + """ + ( + self.db.from_("notifications") + .delete() + .filter("chat_id", "eq", chat_id) + .execute() + ).data + def get_notifications_by_chat_id(self, chat_id: UUID) -> list[Notification]: """ Get all notifications for a chat @@ -102,9 +115,11 @@ def get_notifications_by_chat_id(self, chat_id: UUID) -> list[Notification]: Returns: list[Notification]: The notifications """ - return ( + notifications = ( self.db.from_("notifications") .select("*") .filter("chat_id", "eq", chat_id) .execute() ).data + + return [Notification(**notification) for notification in notifications] diff --git a/backend/repository/chat/get_chat_history_with_notifications.py b/backend/repository/chat/get_chat_history_with_notifications.py new file mode 100644 index 000000000000..de2ec18df8ec --- /dev/null +++ b/backend/repository/chat/get_chat_history_with_notifications.py @@ -0,0 +1,57 @@ +from enum import Enum +from typing import List, Union +from uuid import UUID + +from models.notifications import Notification +from pydantic import BaseModel +from utils.parse_message_time import ( + parse_message_time, +) + +from repository.chat.get_chat_history import GetChatHistoryOutput, get_chat_history +from repository.notification.get_chat_notifications import ( + get_chat_notifications, +) + + +class ChatItemType(Enum): + MESSAGE = "MESSAGE" + NOTIFICATION = "NOTIFICATION" + + +class ChatItem(BaseModel): + item_type: ChatItemType + body: Union[GetChatHistoryOutput, Notification] + + +def merge_chat_history_and_notifications( + chat_history: List[GetChatHistoryOutput], notifications: List[Notification] +) -> List[ChatItem]: + chat_history_and_notifications = chat_history + notifications + + chat_history_and_notifications.sort( + key=lambda x: parse_message_time(x.message_time) + if isinstance(x, GetChatHistoryOutput) + else parse_message_time(x.datetime) + ) + + transformed_data = [] + for item in chat_history_and_notifications: + if isinstance(item, GetChatHistoryOutput): + item_type = ChatItemType.MESSAGE + body = item + else: + item_type = ChatItemType.NOTIFICATION + body = item + transformed_item = ChatItem(item_type=item_type, body=body) + transformed_data.append(transformed_item) + + return transformed_data + + +def get_chat_history_with_notifications( + chat_id: UUID, +) -> List[ChatItem]: + chat_history = get_chat_history(str(chat_id)) + chat_notifications = get_chat_notifications(chat_id) + return merge_chat_history_and_notifications(chat_history, chat_notifications) diff --git a/backend/repository/notification/get_chat_notifications.py b/backend/repository/notification/get_chat_notifications.py new file mode 100644 index 000000000000..c15ae01c63f1 --- /dev/null +++ b/backend/repository/notification/get_chat_notifications.py @@ -0,0 +1,14 @@ +from typing import List +from uuid import UUID + +from models.notifications import Notification +from models.settings import get_supabase_db + + +def get_chat_notifications(chat_id: UUID) -> List[Notification]: + """ + Get notifications by chat_id + """ + supabase_db = get_supabase_db() + + return supabase_db.get_notifications_by_chat_id(chat_id) diff --git a/backend/repository/notification/remove_chat_notifications.py b/backend/repository/notification/remove_chat_notifications.py new file mode 100644 index 000000000000..d7d3d449f4e4 --- /dev/null +++ b/backend/repository/notification/remove_chat_notifications.py @@ -0,0 +1,11 @@ +from uuid import UUID +from models.settings import get_supabase_db + + +def remove_chat_notifications(chat_id: UUID) -> None: + """ + Remove all notifications for a chat + """ + supabase_db = get_supabase_db() + + supabase_db.remove_notifications_by_chat_id(chat_id) diff --git a/backend/routes/chat_routes.py b/backend/routes/chat_routes.py index 3905d70bc0ec..25100110ec71 100644 --- a/backend/routes/chat_routes.py +++ b/backend/routes/chat_routes.py @@ -7,6 +7,9 @@ from auth import AuthBearer, get_current_user from fastapi import APIRouter, Depends, HTTPException, Query, Request from fastapi.responses import StreamingResponse +from repository.notification.remove_chat_notifications import ( + remove_chat_notifications, +) from llm.openai import OpenAIBrainPicking from llm.qa_headless import HeadlessQA from models import ( @@ -26,10 +29,13 @@ GetChatHistoryOutput, create_chat, get_chat_by_id, - get_chat_history, get_user_chats, update_chat, ) +from repository.chat.get_chat_history_with_notifications import ( + ChatItem, + get_chat_history_with_notifications, +) from repository.user_identity import get_user_identity chat_router = APIRouter() @@ -114,6 +120,8 @@ async def delete_chat(chat_id: UUID): Delete a specific chat by chat ID. """ supabase_db = get_supabase_db() + remove_chat_notifications(chat_id) + delete_chat_from_db(supabase_db=supabase_db, chat_id=chat_id) return {"message": f"{chat_id} has been deleted."} @@ -333,6 +341,6 @@ async def create_stream_question_handler( ) async def get_chat_history_handler( chat_id: UUID, -) -> List[GetChatHistoryOutput]: +) -> List[ChatItem]: # TODO: RBAC with current_user - return get_chat_history(str(chat_id)) + return get_chat_history_with_notifications(chat_id) diff --git a/backend/routes/crawl_routes.py b/backend/routes/crawl_routes.py index 6f962581bfa7..0cf6940dc911 100644 --- a/backend/routes/crawl_routes.py +++ b/backend/routes/crawl_routes.py @@ -33,6 +33,7 @@ async def crawl_endpoint( request: Request, crawl_website: CrawlWebsite, brain_id: UUID = Query(..., description="The ID of the brain"), + chat_id: UUID = Query(..., description="The ID of the chat"), enable_summarization: bool = False, current_user: UserIdentity = Depends(get_current_user), ): @@ -56,11 +57,15 @@ async def crawl_endpoint( "type": "error", } else: - crawl_notification = add_notification( - CreateNotificationProperties( - action="CRAWL", + crawl_notification = None + if chat_id: + crawl_notification = add_notification( + CreateNotificationProperties( + action="CRAWL", + chat_id=chat_id, + status=NotificationsStatusEnum.Pending, + ) ) - ) if not crawl_website.checkGithub(): ( file_path, @@ -92,10 +97,11 @@ async def crawl_endpoint( brain_id=brain_id, user_openai_api_key=request.headers.get("Openai-Api-Key", None), ) - update_notification_by_id( - crawl_notification.id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.Done, message=str(message) - ), - ) + if crawl_notification: + update_notification_by_id( + crawl_notification.id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.Done, message=str(message) + ), + ) return message diff --git a/backend/routes/upload_routes.py b/backend/routes/upload_routes.py index 1b3a49587fe2..2e757c1f37b0 100644 --- a/backend/routes/upload_routes.py +++ b/backend/routes/upload_routes.py @@ -36,6 +36,7 @@ async def upload_file( request: Request, uploadFile: UploadFile, brain_id: UUID = Query(..., description="The ID of the brain"), + chat_id: UUID = Query(..., description="The ID of the chat"), enable_summarization: bool = False, current_user: UserIdentity = Depends(get_current_user), ): @@ -71,11 +72,15 @@ async def upload_file( "type": "error", } else: - upload_notification = add_notification( - CreateNotificationProperties( - action="UPLOAD", + upload_notification = None + if chat_id: + upload_notification = add_notification( + CreateNotificationProperties( + action="UPLOAD", + chat_id=chat_id, + status=NotificationsStatusEnum.Pending, + ) ) - ) openai_api_key = request.headers.get("Openai-Api-Key", None) if openai_api_key is None: brain_details = get_brain_details(brain_id) @@ -91,11 +96,13 @@ async def upload_file( brain_id=brain_id, openai_api_key=openai_api_key, ) - update_notification_by_id( - upload_notification.id, - NotificationUpdatableProperties( - status=NotificationsStatusEnum.Done, message=str(message) - ), - ) + + if upload_notification: + update_notification_by_id( + upload_notification.id, + NotificationUpdatableProperties( + status=NotificationsStatusEnum.Done, message=str(message) + ), + ) return message diff --git a/backend/utils/parse_message_time.py b/backend/utils/parse_message_time.py new file mode 100644 index 000000000000..94387b0acb9c --- /dev/null +++ b/backend/utils/parse_message_time.py @@ -0,0 +1,5 @@ +from datetime import datetime + + +def parse_message_time(message_time_str): + return datetime.strptime(message_time_str, "%Y-%m-%dT%H:%M:%S.%f") diff --git a/frontend/app/chat/[chatId]/components/ActionsBar/hooks/useKnowledgeUploader.ts b/frontend/app/chat/[chatId]/components/ActionsBar/hooks/useKnowledgeUploader.ts index 54b85eeb0384..b39e04247adf 100644 --- a/frontend/app/chat/[chatId]/components/ActionsBar/hooks/useKnowledgeUploader.ts +++ b/frontend/app/chat/[chatId]/components/ActionsBar/hooks/useKnowledgeUploader.ts @@ -1,9 +1,11 @@ /* eslint-disable max-lines */ import axios from "axios"; import { UUID } from "crypto"; +import { useParams } from "next/navigation"; import { useCallback, useState } from "react"; import { useTranslation } from "react-i18next"; +import { useChatApi } from "@/lib/api/chat/useChatApi"; import { useCrawlApi } from "@/lib/api/crawl/useCrawlApi"; import { useUploadApi } from "@/lib/api/upload/useUploadApi"; import { useBrainContext } from "@/lib/context/BrainProvider/hooks/useBrainContext"; @@ -18,6 +20,10 @@ export const useKnowledgeUploader = () => { const { uploadFile } = useUploadApi(); const { t } = useTranslation(["upload"]); const { crawlWebsiteUrl } = useCrawlApi(); + const { createChat } = useChatApi(); + + const params = useParams(); + const chatId = params?.chatId as UUID | undefined; const { currentBrainId } = useBrainContext(); const addContent = (content: FeedItemType) => { @@ -28,7 +34,7 @@ export const useKnowledgeUploader = () => { }; const crawlWebsiteHandler = useCallback( - async (url: string, brainId: UUID) => { + async (url: string, brainId: UUID, chat_id: UUID) => { // Configure parameters const config = { url: url, @@ -42,6 +48,7 @@ export const useKnowledgeUploader = () => { await crawlWebsiteUrl({ brainId, config, + chat_id, }); } catch (error: unknown) { publish({ @@ -56,13 +63,14 @@ export const useKnowledgeUploader = () => { ); const uploadFileHandler = useCallback( - async (file: File, brainId: UUID) => { + async (file: File, brainId: UUID, chat_id: UUID) => { const formData = new FormData(); formData.append("uploadFile", file); try { await uploadFile({ - brainId: brainId, + brainId, formData, + chat_id, }); } catch (e: unknown) { if (axios.isAxiosError(e) && e.response?.status === 403) { @@ -104,12 +112,14 @@ export const useKnowledgeUploader = () => { return; } + try { + const currentChatId = chatId ?? (await createChat("New Chat")).chat_id; const uploadPromises = files.map((file) => - uploadFileHandler(file, currentBrainId) + uploadFileHandler(file, currentBrainId, currentChatId) ); const crawlPromises = urls.map((url) => - crawlWebsiteHandler(url, currentBrainId) + crawlWebsiteHandler(url, currentBrainId, currentChatId) ); await Promise.all([...uploadPromises, ...crawlPromises]); diff --git a/frontend/app/chat/[chatId]/hooks/useSelectedChatPage.ts b/frontend/app/chat/[chatId]/hooks/useSelectedChatPage.ts index c4a33ffb216a..152ac122cf66 100644 --- a/frontend/app/chat/[chatId]/hooks/useSelectedChatPage.ts +++ b/frontend/app/chat/[chatId]/hooks/useSelectedChatPage.ts @@ -4,6 +4,8 @@ import { useEffect } from "react"; import { useChatApi } from "@/lib/api/chat/useChatApi"; import { useChatContext } from "@/lib/context"; +import { getMessagesFromChatHistory } from "../utils/getMessagesFromChatHistory"; + // eslint-disable-next-line @typescript-eslint/explicit-module-boundary-types export const useSelectedChatPage = () => { const { setHistory } = useChatContext(); @@ -23,7 +25,7 @@ export const useSelectedChatPage = () => { const chatHistory = await getHistory(chatId); if (chatHistory.length > 0) { - setHistory(chatHistory); + setHistory(getMessagesFromChatHistory(chatHistory)); } }; void fetchHistory(); diff --git a/frontend/app/chat/[chatId]/types/index.ts b/frontend/app/chat/[chatId]/types/index.ts index 382b834fbe4b..faefcc6ef908 100644 --- a/frontend/app/chat/[chatId]/types/index.ts +++ b/frontend/app/chat/[chatId]/types/index.ts @@ -18,6 +18,22 @@ export type ChatHistory = { brain_name?: string; }; +type HistoryItemType = "MESSAGE" | "NOTIFICATION"; + +type Notification = { + id: string; + datetime: string; + chat_id?: string | null; + message?: string | null; + action: string; + status: string; +}; + +export type ChatItem = { + item_type: HistoryItemType; + body: ChatHistory | Notification; +}; + export type ChatEntity = { chat_id: UUID; user_id: string; diff --git a/frontend/app/chat/[chatId]/utils/getMessagesFromChatHistory.ts b/frontend/app/chat/[chatId]/utils/getMessagesFromChatHistory.ts new file mode 100644 index 000000000000..c5ae1926f2d5 --- /dev/null +++ b/frontend/app/chat/[chatId]/utils/getMessagesFromChatHistory.ts @@ -0,0 +1,11 @@ +import { ChatHistory, ChatItem } from "../types"; + +export const getMessagesFromChatHistory = ( + chatHistory: ChatItem[] +): ChatHistory[] => { + const messages = chatHistory + .filter((item) => item.item_type === "MESSAGE") + .map((item) => item.body as ChatHistory); + + return messages; +}; diff --git a/frontend/lib/api/chat/chat.ts b/frontend/lib/api/chat/chat.ts index bf58f0defeea..226fbcdfa868 100644 --- a/frontend/lib/api/chat/chat.ts +++ b/frontend/lib/api/chat/chat.ts @@ -3,6 +3,7 @@ import { AxiosInstance } from "axios"; import { ChatEntity, ChatHistory, + ChatItem, ChatQuestion, } from "@/app/chat/[chatId]/types"; @@ -55,8 +56,8 @@ export const addQuestion = async ( export const getHistory = async ( chatId: string, axiosInstance: AxiosInstance -): Promise => - (await axiosInstance.get(`/chat/${chatId}/history`)).data; +): Promise => + (await axiosInstance.get(`/chat/${chatId}/history`)).data; export type ChatUpdatableProperties = { chat_name?: string; diff --git a/frontend/lib/api/crawl/__tests__/useCrawlApi.test.ts b/frontend/lib/api/crawl/__tests__/useCrawlApi.test.ts index 1938042dafa7..24625b135608 100644 --- a/frontend/lib/api/crawl/__tests__/useCrawlApi.test.ts +++ b/frontend/lib/api/crawl/__tests__/useCrawlApi.test.ts @@ -24,6 +24,7 @@ describe("useCrawlApi", () => { } = renderHook(() => useCrawlApi()); const crawlInputProps: CrawlInputProps = { brainId: "e7001ccd-6d90-4eab-8c50-2f23d39441e4", + chat_id: "e7001ccd-6d90-4eab-8c50-2f23d39441es", config: { url: "https://en.wikipedia.org/wiki/Mali", js: false, @@ -36,7 +37,9 @@ describe("useCrawlApi", () => { expect(axiosPostMock).toHaveBeenCalledTimes(1); expect(axiosPostMock).toHaveBeenCalledWith( - `/crawl?brain_id=${crawlInputProps.brainId}`, + `/crawl?brain_id=${crawlInputProps.brainId}&chat_id=${ + crawlInputProps.chat_id ?? "" + }`, crawlInputProps.config ); }); diff --git a/frontend/lib/api/crawl/crawl.ts b/frontend/lib/api/crawl/crawl.ts index e57514401fe8..7263f6c5573b 100644 --- a/frontend/lib/api/crawl/crawl.ts +++ b/frontend/lib/api/crawl/crawl.ts @@ -5,6 +5,7 @@ import { ToastData } from "@/lib/components/ui/Toast/domain/types"; export type CrawlInputProps = { brainId: UUID; + chat_id?: UUID; config: { url: string; js: boolean; @@ -22,4 +23,7 @@ export const crawlWebsiteUrl = async ( props: CrawlInputProps, axiosInstance: AxiosInstance ): Promise => - axiosInstance.post(`/crawl?brain_id=${props.brainId}`, props.config); + axiosInstance.post( + `/crawl?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`, + props.config + ); diff --git a/frontend/lib/api/upload/upload.ts b/frontend/lib/api/upload/upload.ts index d2e27e13ec70..7d2a48d76caf 100644 --- a/frontend/lib/api/upload/upload.ts +++ b/frontend/lib/api/upload/upload.ts @@ -10,10 +10,14 @@ export type UploadResponse = { export type UploadInputProps = { brainId: UUID; formData: FormData; + chat_id?: UUID; }; export const uploadFile = async ( props: UploadInputProps, axiosInstance: AxiosInstance ): Promise => - axiosInstance.post(`/upload?brain_id=${props.brainId}`, props.formData); + axiosInstance.post( + `/upload?brain_id=${props.brainId}&chat_id=${props.chat_id ?? ""}`, + props.formData + );