diff --git a/.gitignore b/.gitignore index 7159f88b2c70..92ef296891f1 100644 --- a/.gitignore +++ b/.gitignore @@ -59,3 +59,4 @@ bun.lockb sitemap*.xml robots.txt +*.patch diff --git a/package.json b/package.json index fdbd85b2142b..7ff95b559b48 100644 --- a/package.json +++ b/package.json @@ -148,7 +148,7 @@ "uuid": "^9", "yaml": "^2", "zod": "^3", - "zustand": "^4.4", + "zustand": "^4.5.2", "zustand-utils": "^1.3.2" }, "devDependencies": { diff --git a/src/app/api/plugin/store/route.ts b/src/app/api/plugin/store/route.ts index afe198ba99be..a203a2467f5b 100644 --- a/src/app/api/plugin/store/route.ts +++ b/src/app/api/plugin/store/route.ts @@ -11,7 +11,7 @@ export const GET = async (req: Request) => { let res: Response; - res = await fetch(pluginStore.getPluginIndexUrl(locale as any)); + res = await fetch(pluginStore.getPluginIndexUrl(locale as any), { next: { revalidate: 3600 } }); if (res.status === 404) { res = await fetch(pluginStore.getPluginIndexUrl(DEFAULT_LANG)); diff --git a/src/database/models/message.ts b/src/database/models/message.ts index 6d9ec5b7f4c4..143c71a3e247 100644 --- a/src/database/models/message.ts +++ b/src/database/models/message.ts @@ -26,19 +26,8 @@ class _MessageModel extends BaseModel { constructor() { super('messages', DB_MessageSchema); } - async create(data: CreateMessageParams) { - const id = nanoid(); - - const messageData: DB_Message = this.mapChatMessageToDBMessage(data as ChatMessage); - - return this._add(messageData, id); - } - - async batchCreate(messages: ChatMessage[]) { - const data: DB_Message[] = messages.map((m) => this.mapChatMessageToDBMessage(m)); - return this._batchAdd(data); - } + // **************** Query *************** // async query({ sessionId, @@ -91,45 +80,79 @@ class _MessageModel extends BaseModel { return this.table.get(id); } - async delete(id: string) { - return this.table.delete(id); + async queryAll() { + const data: DBModel[] = await this.table.orderBy('updatedAt').toArray(); + + return data.map((element) => this.mapToChatMessage(element)); } - async clearTable() { - return this.table.clear(); + async queryBySessionId(sessionId: string) { + return this.table.where('sessionId').equals(sessionId).toArray(); } - async update(id: string, data: DeepPartial) { - return super._update(id, data); + queryByTopicId = async (topicId: string) => { + const dbMessages = await this.table.where('topicId').equals(topicId).toArray(); + + return dbMessages.map((message) => this.mapToChatMessage(message)); + }; + + async count() { + return this.table.count(); } - async updatePluginState(id: string, key: string, value: any) { - const item = await this.findById(id); + // **************** Create *************** // - return this.update(id, { pluginState: { ...item.pluginState, [key]: value } }); + async create(data: CreateMessageParams) { + const id = nanoid(); + + const messageData: DB_Message = this.mapChatMessageToDBMessage(data as ChatMessage); + + return this._add(messageData, id); } - /** - * Batch updates multiple fields of the specified messages. - * - * @param {string[]} messageIds - The identifiers of the messages to be updated. - * @param {Partial} updateFields - An object containing the fields to update and their new values. - * @returns {Promise} - The number of updated messages. - */ - async batchUpdate(messageIds: string[], updateFields: Partial): Promise { - // Retrieve the messages by their IDs - const messagesToUpdate = await this.table.where(':id').anyOf(messageIds).toArray(); + async batchCreate(messages: ChatMessage[]) { + const data: DB_Message[] = messages.map((m) => this.mapChatMessageToDBMessage(m)); - // Update the specified fields of each message - const updatedMessages = messagesToUpdate.map((message) => ({ - ...message, - ...updateFields, - })); + return this._batchAdd(data); + } - // Use the bulkPut method to update the messages in bulk - await this.table.bulkPut(updatedMessages); + async duplicateMessages(messages: ChatMessage[]): Promise { + const duplicatedMessages = await this.createDuplicateMessages(messages); + // 批量添加复制后的消息到数据库 + await this.batchCreate(duplicatedMessages); + return duplicatedMessages; + } - return updatedMessages.length; + async createDuplicateMessages(messages: ChatMessage[]): Promise { + // 创建一个映射来存储原始消息ID和复制消息ID之间的关系 + const idMapping = new Map(); + + // 首先复制所有消息,并为每个复制的消息生成新的ID + const duplicatedMessages = messages.map((originalMessage) => { + const newId = nanoid(); + idMapping.set(originalMessage.id, newId); + + return { ...originalMessage, id: newId }; + }); + + // 更新 parentId 为复制后的新ID + for (const duplicatedMessage of duplicatedMessages) { + if (duplicatedMessage.parentId && idMapping.has(duplicatedMessage.parentId)) { + duplicatedMessage.parentId = idMapping.get(duplicatedMessage.parentId); + } + } + + return duplicatedMessages; + } + + // **************** Delete *************** // + + async delete(id: string) { + return this.table.delete(id); + } + + async clearTable() { + return this.table.clear(); } /** @@ -158,55 +181,43 @@ class _MessageModel extends BaseModel { return this.table.bulkDelete(messageIds); } - async queryAll() { - const data: DBModel[] = await this.table.orderBy('updatedAt').toArray(); - - return data.map((element) => this.mapToChatMessage(element)); - } - - async count() { - return this.table.count(); - } + // **************** Update *************** // - async queryBySessionId(sessionId: string) { - return this.table.where('sessionId').equals(sessionId).toArray(); + async update(id: string, data: DeepPartial) { + return this._update(id, data); } - queryByTopicId = async (topicId: string) => { - const dbMessages = await this.table.where('topicId').equals(topicId).toArray(); - - return dbMessages.map((message) => this.mapToChatMessage(message)); - }; + async updatePluginState(id: string, key: string, value: any) { + const item = await this.findById(id); - async duplicateMessages(messages: ChatMessage[]): Promise { - const duplicatedMessages = await this.createDuplicateMessages(messages); - // 批量添加复制后的消息到数据库 - await this.batchCreate(duplicatedMessages); - return duplicatedMessages; + return this.update(id, { pluginState: { ...item.pluginState, [key]: value } }); } - async createDuplicateMessages(messages: ChatMessage[]): Promise { - // 创建一个映射来存储原始消息ID和复制消息ID之间的关系 - const idMapping = new Map(); - - // 首先复制所有消息,并为每个复制的消息生成新的ID - const duplicatedMessages = messages.map((originalMessage) => { - const newId = nanoid(); - idMapping.set(originalMessage.id, newId); + /** + * Batch updates multiple fields of the specified messages. + * + * @param {string[]} messageIds - The identifiers of the messages to be updated. + * @param {Partial} updateFields - An object containing the fields to update and their new values. + * @returns {Promise} - The number of updated messages. + */ + async batchUpdate(messageIds: string[], updateFields: Partial): Promise { + // Retrieve the messages by their IDs + const messagesToUpdate = await this.table.where(':id').anyOf(messageIds).toArray(); - return { ...originalMessage, id: newId }; - }); + // Update the specified fields of each message + const updatedMessages = messagesToUpdate.map((message) => ({ + ...message, + ...updateFields, + })); - // 更新 parentId 为复制后的新ID - for (const duplicatedMessage of duplicatedMessages) { - if (duplicatedMessage.parentId && idMapping.has(duplicatedMessage.parentId)) { - duplicatedMessage.parentId = idMapping.get(duplicatedMessage.parentId); - } - } + // Use the bulkPut method to update the messages in bulk + await this.table.bulkPut(updatedMessages); - return duplicatedMessages; + return updatedMessages.length; } + // **************** Helper *************** // + private mapChatMessageToDBMessage(message: ChatMessage): DB_Message { const { extra, ...messageData } = message; diff --git a/src/database/models/plugin.ts b/src/database/models/plugin.ts index dd78db1bbebf..a12042ec893e 100644 --- a/src/database/models/plugin.ts +++ b/src/database/models/plugin.ts @@ -7,11 +7,14 @@ class _PluginModel extends BaseModel { constructor() { super('plugins', DB_PluginSchema); } + // **************** Query *************** // getList = async (): Promise => { return this.table.toArray(); }; + // **************** Create *************** // + create = async (plugin: DB_Plugin) => { const old = await this.table.get(plugin.identifier); @@ -21,18 +24,20 @@ class _PluginModel extends BaseModel { batchCreate = async (plugins: DB_Plugin[]) => { return this._batchAdd(plugins); }; + // **************** Delete *************** // delete(id: string) { return this.table.delete(id); } + clear() { + return this.table.clear(); + } + + // **************** Update *************** // update: (id: string, value: Partial) => Promise = async (id, value) => { return this.table.update(id, value); }; - - clear() { - return this.table.clear(); - } } export const PluginModel = new _PluginModel(); diff --git a/src/database/models/session.ts b/src/database/models/session.ts index 91abda40e61f..43ba801731c4 100644 --- a/src/database/models/session.ts +++ b/src/database/models/session.ts @@ -21,29 +21,7 @@ class _SessionModel extends BaseModel { super('sessions', DB_SessionSchema); } - async create(type: 'agent' | 'group', defaultValue: Partial, id = uuid()) { - const data = merge(DEFAULT_AGENT_LOBE_SESSION, { type, ...defaultValue }); - const dataDB = this.mapToDB_Session(data); - return this._add(dataDB, id); - } - - async batchCreate(sessions: LobeAgentSession[]) { - const DB_Sessions = await Promise.all( - sessions.map(async (s) => { - if (s.group && s.group !== SessionDefaultGroup.Default) { - // Check if the group exists in the SessionGroup table - const groupExists = await SessionGroupModel.findById(s.group); - // If the group does not exist, set it to default - if (!groupExists) { - s.group = SessionDefaultGroup.Default; - } - } - return this.mapToDB_Session(s); - }), - ); - - return this._batchAdd(DB_Sessions, { idGenerator: uuid }); - } + // **************** Query *************** // async query({ pageSize = 9999, @@ -103,59 +81,6 @@ class _SessionModel extends BaseModel { return Object.fromEntries(groupItems); } - async update(id: string, data: Partial) { - return super._update(id, data); - } - - async updatePinned(id: string, pinned: boolean) { - return this.update(id, { pinned: pinned ? 1 : 0 }); - } - - async updateConfig(id: string, data: DeepPartial) { - const session = await this.findById(id); - if (!session) return; - - const config = merge(session.config, data); - - return this.update(id, { config }); - } - - /** - * Delete a session , also delete all messages and topic associated with it. - */ - async delete(id: string) { - return this.db.transaction('rw', [this.table, this.db.topics, this.db.messages], async () => { - // Delete all topics associated with the session - const topics = await this.db.topics.where('sessionId').equals(id).toArray(); - const topicIds = topics.map((topic) => topic.id); - if (topicIds.length > 0) { - await this.db.topics.bulkDelete(topicIds); - } - - // Delete all messages associated with the session - const messages = await this.db.messages.where('sessionId').equals(id).toArray(); - const messageIds = messages.map((message) => message.id); - if (messageIds.length > 0) { - await this.db.messages.bulkDelete(messageIds); - } - - // Finally, delete the session itself - await this.table.delete(id); - }); - } - - async clearTable() { - return this.table.clear(); - } - - async findById(id: string): Promise> { - return this.table.get(id); - } - - async isEmpty() { - return (await this.table.count()) === 0; - } - /** * Query sessions by keyword in title, description, content, or translated content * @param keyword The keyword to search for @@ -225,6 +150,50 @@ class _SessionModel extends BaseModel { return this.mapToAgentSessions(items); } + async getPinnedSessions(): Promise { + const items: DBModel[] = await this.table + .where('pinned') + .equals(1) + .reverse() + .sortBy('updatedAt'); + + return this.mapToAgentSessions(items); + } + + async findById(id: string): Promise> { + return this.table.get(id); + } + + async isEmpty() { + return (await this.table.count()) === 0; + } + + // **************** Create *************** // + + async create(type: 'agent' | 'group', defaultValue: Partial, id = uuid()) { + const data = merge(DEFAULT_AGENT_LOBE_SESSION, { type, ...defaultValue }); + const dataDB = this.mapToDB_Session(data); + return this._add(dataDB, id); + } + + async batchCreate(sessions: LobeAgentSession[]) { + const DB_Sessions = await Promise.all( + sessions.map(async (s) => { + if (s.group && s.group !== SessionDefaultGroup.Default) { + // Check if the group exists in the SessionGroup table + const groupExists = await SessionGroupModel.findById(s.group); + // If the group does not exist, set it to default + if (!groupExists) { + s.group = SessionDefaultGroup.Default; + } + } + return this.mapToDB_Session(s); + }), + ); + + return this._batchAdd(DB_Sessions, { idGenerator: uuid }); + } + async duplicate(id: string, newTitle?: string) { const session = await this.findById(id); if (!session) return; @@ -234,16 +203,57 @@ class _SessionModel extends BaseModel { return this._add(newSession, uuid()); } - async getPinnedSessions(): Promise { - const items: DBModel[] = await this.table - .where('pinned') - .equals(1) - .reverse() - .sortBy('updatedAt'); + // **************** Delete *************** // - return this.mapToAgentSessions(items); + /** + * Delete a session , also delete all messages and topic associated with it. + */ + async delete(id: string) { + return this.db.transaction('rw', [this.table, this.db.topics, this.db.messages], async () => { + // Delete all topics associated with the session + const topics = await this.db.topics.where('sessionId').equals(id).toArray(); + const topicIds = topics.map((topic) => topic.id); + if (topicIds.length > 0) { + await this.db.topics.bulkDelete(topicIds); + } + + // Delete all messages associated with the session + const messages = await this.db.messages.where('sessionId').equals(id).toArray(); + const messageIds = messages.map((message) => message.id); + if (messageIds.length > 0) { + await this.db.messages.bulkDelete(messageIds); + } + + // Finally, delete the session itself + await this.table.delete(id); + }); + } + + async clearTable() { + return this.table.clear(); + } + + // **************** Update *************** // + + async update(id: string, data: Partial) { + return this._update(id, data); + } + + async updatePinned(id: string, pinned: boolean) { + return this.update(id, { pinned: pinned ? 1 : 0 }); } + async updateConfig(id: string, data: DeepPartial) { + const session = await this.findById(id); + if (!session) return; + + const config = merge(session.config, data); + + return this.update(id, { config }); + } + + // **************** Helper *************** // + private mapToDB_Session(session: LobeAgentSession): DBModel { return { ...session, diff --git a/src/database/models/sessionGroup.ts b/src/database/models/sessionGroup.ts index cff661c4f1b1..01d3eeddbe1a 100644 --- a/src/database/models/sessionGroup.ts +++ b/src/database/models/sessionGroup.ts @@ -8,32 +8,7 @@ class _SessionGroupModel extends BaseModel { super('sessionGroups', DB_SessionGroupSchema); } - async create(name: string, sort?: number, id = nanoid()) { - return this._add({ name, sort }, id); - } - async batchCreate(groups: SessionGroups) { - return this._batchAdd(groups, { idGenerator: nanoid }); - } - - async findById(id: string): Promise { - return this.table.get(id); - } - - async update(id: string, data: Partial) { - return super._update(id, data); - } - - async delete(id: string, removeGroupItem: boolean = false) { - this.db.sessions.toCollection().modify((session) => { - // update all session associated with the sessionGroup to default - if (session.group === id) session.group = 'default'; - }); - if (!removeGroupItem) { - return this.table.delete(id); - } else { - return this.db.sessions.where('group').equals(id).delete(); - } - } + // **************** Query *************** // async query(): Promise { const allGroups = await this.table.toArray(); @@ -60,6 +35,43 @@ class _SessionGroupModel extends BaseModel { }); } + async findById(id: string): Promise { + return this.table.get(id); + } + + // **************** Create *************** // + + async create(name: string, sort?: number, id = nanoid()) { + return this._add({ name, sort }, id); + } + + async batchCreate(groups: SessionGroups) { + return this._batchAdd(groups, { idGenerator: nanoid }); + } + + // **************** Delete *************** // + async delete(id: string, removeGroupItem: boolean = false) { + this.db.sessions.toCollection().modify((session) => { + // update all session associated with the sessionGroup to default + if (session.group === id) session.group = 'default'; + }); + if (!removeGroupItem) { + return this.table.delete(id); + } else { + return this.db.sessions.where('group').equals(id).delete(); + } + } + + async clear() { + this.table.clear(); + } + + // **************** Update *************** // + + async update(id: string, data: Partial) { + return super._update(id, data); + } + async updateOrder(sortMap: { id: string; sort: number }[]) { return this.db.transaction('rw', this.table, async () => { for (const { id, sort } of sortMap) { @@ -67,10 +79,6 @@ class _SessionGroupModel extends BaseModel { } }); } - - async clear() { - this.table.clear(); - } } export const SessionGroupModel = new _SessionGroupModel(); diff --git a/src/database/models/topic.ts b/src/database/models/topic.ts index 72a065ab66c1..addb36658bc1 100644 --- a/src/database/models/topic.ts +++ b/src/database/models/topic.ts @@ -23,19 +23,7 @@ class _TopicModel extends BaseModel { super('topics', DB_TopicSchema); } - async create({ title, favorite, sessionId, messages }: CreateTopicParams, id = nanoid()) { - const topic = await this._add({ favorite: favorite ? 1 : 0, sessionId, title: title }, id); - - // add topicId to these messages - if (messages) { - await this.db.messages.where('id').anyOf(messages).modify({ topicId: topic.id }); - } - return topic; - } - - async batchCreate(topics: CreateTopicParams[]) { - return this._batchAdd(topics.map((t) => ({ ...t, favorite: t.favorite ? 1 : 0 }))); - } + // **************** Query *************** // async query({ pageSize = 9999, current = 0, sessionId }: QueryTopicParams): Promise { const offset = current * pageSize; @@ -58,90 +46,6 @@ class _TopicModel extends BaseModel { return pagedTopics.map((i) => this.mapToChatTopic(i)); } - async findBySessionId(sessionId: string) { - return this.table.where({ sessionId }).toArray(); - } - - async findById(id: string): Promise> { - return this.table.get(id); - } - - /** - * Deletes a topic and all messages associated with it. - */ - async delete(id: string) { - return this.db.transaction('rw', [this.table, this.db.messages], async () => { - // Delete all messages associated with the topic - const messages = await this.db.messages.where('topicId').equals(id).toArray(); - - if (messages.length > 0) { - const messageIds = messages.map((msg) => msg.id); - await this.db.messages.bulkDelete(messageIds); - } - - await this.table.delete(id); - }); - } - - /** - * Deletes multiple topic based on the sessionId. - * - * @param {string} sessionId - The identifier of the assistant associated with the messages. - * @returns {Promise} - */ - async batchDeleteBySessionId(sessionId: string): Promise { - // use sessionId as the filter criteria in the query. - const query = this.table.where('sessionId').equals(sessionId); - - // Retrieve a collection of message IDs that satisfy the criteria - const topicIds = await query.primaryKeys(); - - // Use the bulkDelete method to delete all selected messages in bulk - return this.table.bulkDelete(topicIds); - } - - async clearTable() { - return this.table.clear(); - } - - async update(id: string, data: Partial) { - return super._update(id, { ...data, updatedAt: Date.now() }); - } - - async toggleFavorite(id: string, newState?: boolean) { - const topic = await this.findById(id); - if (!topic) { - throw new Error(`Topic with id ${id} not found`); - } - - // Toggle the 'favorite' status - const nextState = typeof newState !== 'undefined' ? newState : !topic.favorite; - - await this.update(id, { favorite: nextState ? 1 : 0 }); - - return nextState; - } - - /** - * Deletes multiple topics and all messages associated with them in a transaction. - */ - async batchDelete(topicIds: string[]) { - return this.db.transaction('rw', [this.table, this.db.messages], async () => { - // Iterate over each topicId and delete related messages, then delete the topic itself - for (const topicId of topicIds) { - // Delete all messages associated with the topic - const messages = await this.db.messages.where('topicId').equals(topicId).toArray(); - if (messages.length > 0) { - const messageIds = messages.map((msg) => msg.id); - await this.db.messages.bulkDelete(messageIds); - } - - // Delete the topic - await this.table.delete(topicId); - } - }); - } - queryAll() { return this.table.orderBy('updatedAt').toArray(); } @@ -201,6 +105,29 @@ class _TopicModel extends BaseModel { return uniqueTopics.map((i) => ({ ...i, favorite: !!i.favorite })); } + async findBySessionId(sessionId: string) { + return this.table.where({ sessionId }).toArray(); + } + + async findById(id: string): Promise> { + return this.table.get(id); + } + + // **************** Create *************** // + + async create({ title, favorite, sessionId, messages }: CreateTopicParams, id = nanoid()) { + const topic = await this._add({ favorite: favorite ? 1 : 0, sessionId, title: title }, id); + + // add topicId to these messages + if (messages) { + await this.db.messages.where('id').anyOf(messages).modify({ topicId: topic.id }); + } + return topic; + } + async batchCreate(topics: CreateTopicParams[]) { + return this._batchAdd(topics.map((t) => ({ ...t, favorite: t.favorite ? 1 : 0 }))); + } + async duplicateTopic(topicId: string, newTitle?: string) { return this.db.transaction('rw', this.db.topics, this.db.messages, async () => { // Step 1: get DB_Topic @@ -226,6 +153,86 @@ class _TopicModel extends BaseModel { }); } + // **************** Delete *************** // + + /** + * Deletes a topic and all messages associated with it. + */ + async delete(id: string) { + return this.db.transaction('rw', [this.table, this.db.messages], async () => { + // Delete all messages associated with the topic + const messages = await this.db.messages.where('topicId').equals(id).toArray(); + + if (messages.length > 0) { + const messageIds = messages.map((msg) => msg.id); + await this.db.messages.bulkDelete(messageIds); + } + + await this.table.delete(id); + }); + } + + /** + * Deletes multiple topic based on the sessionId. + * + * @param {string} sessionId - The identifier of the assistant associated with the messages. + * @returns {Promise} + */ + async batchDeleteBySessionId(sessionId: string): Promise { + // use sessionId as the filter criteria in the query. + const query = this.table.where('sessionId').equals(sessionId); + + // Retrieve a collection of message IDs that satisfy the criteria + const topicIds = await query.primaryKeys(); + + // Use the bulkDelete method to delete all selected messages in bulk + return this.table.bulkDelete(topicIds); + } + /** + * Deletes multiple topics and all messages associated with them in a transaction. + */ + async batchDelete(topicIds: string[]) { + return this.db.transaction('rw', [this.table, this.db.messages], async () => { + // Iterate over each topicId and delete related messages, then delete the topic itself + for (const topicId of topicIds) { + // Delete all messages associated with the topic + const messages = await this.db.messages.where('topicId').equals(topicId).toArray(); + if (messages.length > 0) { + const messageIds = messages.map((msg) => msg.id); + await this.db.messages.bulkDelete(messageIds); + } + + // Delete the topic + await this.table.delete(topicId); + } + }); + } + + async clearTable() { + return this.table.clear(); + } + + // **************** Update *************** // + async update(id: string, data: Partial) { + return this._update(id, data); + } + + async toggleFavorite(id: string, newState?: boolean) { + const topic = await this.findById(id); + if (!topic) { + throw new Error(`Topic with id ${id} not found`); + } + + // Toggle the 'favorite' status + const nextState = typeof newState !== 'undefined' ? newState : !topic.favorite; + + await this.update(id, { favorite: nextState ? 1 : 0 }); + + return nextState; + } + + // **************** Helper *************** // + private mapToChatTopic = (dbTopic: DBModel): ChatTopic => ({ ...dbTopic, favorite: !!dbTopic.favorite, diff --git a/src/database/models/user.ts b/src/database/models/user.ts index 6de2e6aff5fd..3a7ae930f007 100644 --- a/src/database/models/user.ts +++ b/src/database/models/user.ts @@ -10,6 +10,7 @@ class _UserModel extends BaseModel { constructor() { super('users', DB_UserSchema); } + // **************** Query *************** // getUser = async (): Promise => { const noUser = !(await this.table.count()); @@ -21,18 +22,20 @@ class _UserModel extends BaseModel { return list[0]; }; + // **************** Create *************** // + create = async (user: DB_User) => { return this.table.put(user); }; - private update = async (id: number, value: DeepPartial) => { - return this.table.update(id, value); - }; + // **************** Delete *************** // clear() { return this.table.clear(); } + // **************** Update *************** // + async updateSettings(settings: DeepPartial) { const user = await this.getUser(); @@ -50,6 +53,12 @@ class _UserModel extends BaseModel { return this.update(user.id, { avatar }); } + + // **************** Helper *************** // + + private update = async (id: number, value: DeepPartial) => { + return this.table.update(id, value); + }; } export const UserModel = new _UserModel(); diff --git a/src/libs/swr/index.ts b/src/libs/swr/index.ts new file mode 100644 index 000000000000..8b4d643fa058 --- /dev/null +++ b/src/libs/swr/index.ts @@ -0,0 +1,18 @@ +import useSWR, { SWRHook } from 'swr'; + +/** + * 这一类请求方法是比较「死」的请求模式,只会在第一次请求时触发。不会自动刷新,刷新需要搭配 refreshXXX 这样的方法实现, + * 适用于 messages、topics、sessions 等由用户在客户端交互产生的数据。 + */ +// @ts-ignore +export const useClientDataSWR: SWRHook = (key, fetch, config) => + useSWR(key, fetch, { + // default is 2000ms ,it makes the user's quick switch don't work correctly. + // Cause issue like this: https://github.com/lobehub/lobe-chat/issues/532 + // we need to set it to 0. + dedupingInterval: 0, + refreshWhenOffline: false, + revalidateOnFocus: false, + revalidateOnReconnect: false, + ...config, + }); diff --git a/src/store/chat/slices/message/action.test.ts b/src/store/chat/slices/message/action.test.ts index 8ed6226735ee..e85da112ac88 100644 --- a/src/store/chat/slices/message/action.test.ts +++ b/src/store/chat/slices/message/action.test.ts @@ -559,7 +559,7 @@ describe('chatMessage actions', () => { }); // 确保 mutate 调用了正确的参数 - expect(mutate).toHaveBeenCalledWith([activeId, activeTopicId]); + expect(mutate).toHaveBeenCalledWith(['SWR_USE_FETCH_MESSAGES', activeId, activeTopicId]); }); it('should handle errors during refreshing messages', async () => { useChatStore.setState({ refreshMessages: realRefreshMessages }); diff --git a/src/store/chat/slices/message/action.ts b/src/store/chat/slices/message/action.ts index 9fe71a6146d9..ccba6299a445 100644 --- a/src/store/chat/slices/message/action.ts +++ b/src/store/chat/slices/message/action.ts @@ -2,12 +2,13 @@ // Disable the auto sort key eslint rule to make the code more logic and readable import { copyToClipboard } from '@bentwnghk/ui'; import { template } from 'lodash-es'; -import useSWR, { SWRResponse, mutate } from 'swr'; +import { SWRResponse, mutate } from 'swr'; import { StateCreator } from 'zustand/vanilla'; import { LOADING_FLAT, isFunctionMessageAtStart, testFunctionMessageAtEnd } from '@/const/message'; import { TraceEventType, TraceNameMap } from '@/const/trace'; import { CreateMessageParams } from '@/database/models/message'; +import { useClientDataSWR } from '@/libs/swr'; import { chatService } from '@/services/chat'; import { messageService } from '@/services/message'; import { topicService } from '@/services/topic'; @@ -25,6 +26,8 @@ import { MessageDispatch, messagesReducer } from './reducer'; const n = setNamespace('message'); +const SWR_USE_FETCH_MESSAGES = 'SWR_USE_FETCH_MESSAGES'; + interface SendMessageParams { message: string; files?: { id: string; url: string }[]; @@ -274,9 +277,9 @@ export const chatMessage: StateCreator< await get().internalUpdateMessageContent(id, content); }, useFetchMessages: (sessionId, activeTopicId) => - useSWR( - [sessionId, activeTopicId], - async ([sessionId, topicId]: [string, string | undefined]) => + useClientDataSWR( + [SWR_USE_FETCH_MESSAGES, sessionId, activeTopicId], + async ([, sessionId, topicId]: [string, string, string | undefined]) => messageService.getMessages(sessionId, topicId), { onSuccess: (messages, key) => { @@ -289,14 +292,10 @@ export const chatMessage: StateCreator< }), ); }, - // default is 2000ms ,it makes the user's quick switch don't work correctly. - // Cause issue like this: https://github.com/lobehub/lobe-chat/issues/532 - // we need to set it to 0. - dedupingInterval: 0, }, ), refreshMessages: async () => { - await mutate([get().activeId, get().activeTopicId]); + await mutate([SWR_USE_FETCH_MESSAGES, get().activeId, get().activeTopicId]); }, // the internal process method of the AI message diff --git a/src/store/session/slices/session/action.ts b/src/store/session/slices/session/action.ts index 970df826e145..bb3414989b0a 100644 --- a/src/store/session/slices/session/action.ts +++ b/src/store/session/slices/session/action.ts @@ -6,6 +6,7 @@ import { StateCreator } from 'zustand/vanilla'; import { INBOX_SESSION_ID } from '@/const/session'; import { SESSION_CHAT_URL } from '@/const/url'; +import { useClientDataSWR } from '@/libs/swr'; import { sessionService } from '@/services/session'; import { useGlobalStore } from '@/store/global'; import { settingsSelectors } from '@/store/global/selectors'; @@ -154,7 +155,7 @@ export const createSessionSlice: StateCreator< }, useFetchSessions: () => - useSWR(FETCH_SESSIONS_KEY, sessionService.getSessionsWithGroup, { + useClientDataSWR(FETCH_SESSIONS_KEY, sessionService.getSessionsWithGroup, { onSuccess: (data) => { // 由于 https://github.com/lobehub/lobe-chat/pull/541 的关系 // 只有触发了 refreshSessions 才会更新 sessions,进而触发页面 rerender