From c0f94759518a9dbe12f91b47b61b5897eace060d Mon Sep 17 00:00:00 2001 From: Henry Date: Sat, 9 Nov 2024 03:25:40 +0000 Subject: [PATCH 1/2] add postgres to agent memory --- .../nodes/memory/AgentMemory/AgentMemory.ts | 68 ++++- .../nodes/memory/AgentMemory/pgSaver.ts | 250 ++++++++++++++++++ 2 files changed, 315 insertions(+), 3 deletions(-) create mode 100644 packages/components/nodes/memory/AgentMemory/pgSaver.ts diff --git a/packages/components/nodes/memory/AgentMemory/AgentMemory.ts b/packages/components/nodes/memory/AgentMemory/AgentMemory.ts index 2542a73ed90..de8d274c665 100644 --- a/packages/components/nodes/memory/AgentMemory/AgentMemory.ts +++ b/packages/components/nodes/memory/AgentMemory/AgentMemory.ts @@ -1,9 +1,10 @@ import path from 'path' -import { getBaseClasses, getUserHome } from '../../../src/utils' +import { getBaseClasses, getCredentialData, getCredentialParam, getUserHome } from '../../../src/utils' import { SaverOptions } from './interface' import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../src/Interface' import { SqliteSaver } from './sqliteSaver' import { DataSource } from 'typeorm' +import { PostgresSaver } from './pgSaver' class AgentMemory_Memory implements INode { label: string @@ -16,16 +17,24 @@ class AgentMemory_Memory implements INode { badge: string baseClasses: string[] inputs: INodeParams[] + credential: INodeParams constructor() { this.label = 'Agent Memory' this.name = 'agentMemory' - this.version = 1.0 + this.version = 2.0 this.type = 'AgentMemory' this.icon = 'agentmemory.svg' this.category = 'Memory' this.description = 'Memory for agentflow to remember the state of the conversation' this.baseClasses = [this.type, ...getBaseClasses(SqliteSaver)] + this.credential = { + label: 'Connect Credential', + name: 'credential', + type: 'credential', + credentialNames: ['PostgresApi'], + optional: true + } this.inputs = [ { label: 'Database', @@ -35,6 +44,10 @@ class AgentMemory_Memory implements INode { { label: 'SQLite', name: 'sqlite' + }, + { + label: 'Postgres', + name: 'postgres' } ], default: 'sqlite' @@ -49,6 +62,31 @@ class AgentMemory_Memory implements INode { additionalParams: true, optional: true }, + { + label: 'Host', + name: 'host', + type: 'string', + description: 'If Postgres is selected, provide the host of the Postgres database', + additionalParams: true, + optional: true + }, + { + label: 'Database', + name: 'database', + type: 'string', + description: 'If Postgres is selected, provide the name of the Postgres database', + additionalParams: true, + optional: true + }, + { + label: 'Port', + name: 'port', + type: 'number', + description: 'If Postgres is selected, provide the port of the Postgres database', + placeholder: '5432', + additionalParams: true, + optional: true + }, { label: 'Additional Connection Configuration', name: 'additionalConfig', @@ -78,7 +116,7 @@ class AgentMemory_Memory implements INode { const threadId = options.sessionId || options.chatId - const datasourceOptions: ICommonObject = { + let datasourceOptions: ICommonObject = { ...additionalConfiguration, type: databaseType } @@ -96,6 +134,30 @@ class AgentMemory_Memory implements INode { } const recordManager = new SqliteSaver(args) return recordManager + } else if (databaseType === 'postgres') { + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const user = getCredentialParam('user', credentialData, nodeData) + const password = getCredentialParam('password', credentialData, nodeData) + const _port = (nodeData.inputs?.port as string) || '5432' + const port = parseInt(_port) + datasourceOptions = { + ...datasourceOptions, + host: nodeData.inputs?.host as string, + port, + database: nodeData.inputs?.database as string, + username: user, + user: user, + password: password + } + const args: SaverOptions = { + datasourceOptions, + threadId, + appDataSource, + databaseEntities, + chatflowid + } + const recordManager = new PostgresSaver(args) + return recordManager } return undefined diff --git a/packages/components/nodes/memory/AgentMemory/pgSaver.ts b/packages/components/nodes/memory/AgentMemory/pgSaver.ts new file mode 100644 index 00000000000..27e236a7f7a --- /dev/null +++ b/packages/components/nodes/memory/AgentMemory/pgSaver.ts @@ -0,0 +1,250 @@ +import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' +import { RunnableConfig } from '@langchain/core/runnables' +import { BaseMessage } from '@langchain/core/messages' +import { DataSource, QueryRunner } from 'typeorm' +import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' +import { IMessage, MemoryMethods } from '../../../src/Interface' +import { mapChatMessageToBaseMessage } from '../../../src/utils' + +export class PostgresSaver extends BaseCheckpointSaver implements MemoryMethods { + protected isSetup: boolean + + datasource: DataSource + + queryRunner: QueryRunner + + config: SaverOptions + + threadId: string + + tableName = 'checkpoints' + + constructor(config: SaverOptions, serde?: SerializerProtocol) { + super(serde) + this.config = config + const { datasourceOptions, threadId } = config + this.threadId = threadId + this.datasource = new DataSource(datasourceOptions) + } + + private async setup(): Promise { + if (this.isSetup) { + return + } + + try { + const appDataSource = await this.datasource.initialize() + + this.queryRunner = appDataSource.createQueryRunner() + await this.queryRunner.manager.query(` +CREATE TABLE IF NOT EXISTS ${this.tableName} ( + thread_id TEXT NOT NULL, + checkpoint_id TEXT NOT NULL, + parent_id TEXT, + checkpoint BYTEA, + metadata BYTEA, + PRIMARY KEY (thread_id, checkpoint_id));`) + } catch (error) { + console.error(`Error creating ${this.tableName} table`, error) + throw new Error(`Error creating ${this.tableName} table`) + } + + this.isSetup = true + } + + async getTuple(config: RunnableConfig): Promise { + await this.setup() + const thread_id = config.configurable?.thread_id || this.threadId + const checkpoint_id = config.configurable?.checkpoint_id + + if (checkpoint_id) { + try { + const keys = [thread_id, checkpoint_id] + const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = $1 AND checkpoint_id = $2` + + const rows = await this.queryRunner.manager.query(sql, keys) + + if (rows && rows.length > 0) { + return { + config, + checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id, + checkpoint_id: rows[0].parent_id + } + } + : undefined + } + } + } catch (error) { + console.error(`Error retrieving ${this.tableName}`, error) + throw new Error(`Error retrieving ${this.tableName}`) + } + } else { + const keys = [thread_id] + const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1 ORDER BY checkpoint_id DESC LIMIT 1` + + const rows = await this.queryRunner.manager.query(sql, keys) + + if (rows && rows.length > 0) { + return { + config: { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].checkpoint_id + } + }, + checkpoint: (await this.serde.parse(rows[0].checkpoint)) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata)) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].parent_id + } + } + : undefined + } + } + } + return undefined + } + + async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { + await this.setup() + const thread_id = config.configurable?.thread_id || this.threadId + let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = $1` + const args = [thread_id] + + if (before?.configurable?.checkpoint_id) { + sql += ' AND checkpoint_id < $2' + args.push(before.configurable.checkpoint_id) + } + + sql += ' ORDER BY checkpoint_id DESC' + if (limit) { + sql += ` LIMIT ${limit}` + } + + try { + const rows = await this.queryRunner.manager.query(sql, args) + + if (rows && rows.length > 0) { + for (const row of rows) { + yield { + config: { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.checkpoint_id + } + }, + checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, + parentConfig: row.parent_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.parent_id + } + } + : undefined + } + } + } + } catch (error) { + console.error(`Error listing ${this.tableName}`, error) + throw new Error(`Error listing ${this.tableName}`) + } + } + + async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise { + await this.setup() + if (!config.configurable?.checkpoint_id) return {} + try { + const row = [ + config.configurable?.thread_id || this.threadId, + checkpoint.id, + config.configurable?.checkpoint_id, + Buffer.from(this.serde.stringify(checkpoint)), // Encode to binary + Buffer.from(this.serde.stringify(metadata)) // Encode to binary + ] + + const query = `INSERT INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (thread_id, checkpoint_id) + DO UPDATE SET checkpoint = EXCLUDED.checkpoint, metadata = EXCLUDED.metadata` + + await this.queryRunner.manager.query(query, row) + } catch (error) { + console.error('Error saving checkpoint', error) + throw new Error('Error saving checkpoint') + } + + return { + configurable: { + thread_id: config.configurable?.thread_id || this.threadId, + checkpoint_id: checkpoint.id + } + } + } + + async delete(threadId: string): Promise { + if (!threadId) { + return + } + await this.setup() + const query = `DELETE FROM "${this.tableName}" WHERE thread_id = $1;` + + try { + await this.queryRunner.manager.query(query, [threadId]) + } catch (error) { + console.error(`Error deleting thread_id ${threadId}`, error) + } + } + + async getChatMessages( + overrideSessionId = '', + returnBaseMessages = false, + prependMessages?: IMessage[] + ): Promise { + if (!overrideSessionId) return [] + + const chatMessage = await this.config.appDataSource.getRepository(this.config.databaseEntities['ChatMessage']).find({ + where: { + sessionId: overrideSessionId, + chatflowid: this.config.chatflowid + }, + order: { + createdDate: 'ASC' + } + }) + + if (prependMessages?.length) { + chatMessage.unshift(...prependMessages) + } + + if (returnBaseMessages) { + return await mapChatMessageToBaseMessage(chatMessage) + } + + let returnIMessages: IMessage[] = [] + for (const m of chatMessage) { + returnIMessages.push({ + message: m.content as string, + type: m.role + }) + } + return returnIMessages + } + + async addChatMessages(): Promise { + // Empty as it's not being used + } + + async clearChatMessages(overrideSessionId = ''): Promise { + await this.delete(overrideSessionId) + } +} From a9919b45b1ed76362e6bf99f1c371052dfd46133 Mon Sep 17 00:00:00 2001 From: Henry Date: Sat, 9 Nov 2024 14:07:11 +0000 Subject: [PATCH 2/2] add mysql agent memory --- .../nodes/memory/AgentMemory/AgentMemory.ts | 41 ++- .../nodes/memory/AgentMemory/mysqlSaver.ts | 245 ++++++++++++++++++ 2 files changed, 280 insertions(+), 6 deletions(-) create mode 100644 packages/components/nodes/memory/AgentMemory/mysqlSaver.ts diff --git a/packages/components/nodes/memory/AgentMemory/AgentMemory.ts b/packages/components/nodes/memory/AgentMemory/AgentMemory.ts index de8d274c665..33ef61cb2c8 100644 --- a/packages/components/nodes/memory/AgentMemory/AgentMemory.ts +++ b/packages/components/nodes/memory/AgentMemory/AgentMemory.ts @@ -5,6 +5,7 @@ import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '. import { SqliteSaver } from './sqliteSaver' import { DataSource } from 'typeorm' import { PostgresSaver } from './pgSaver' +import { MySQLSaver } from './mysqlSaver' class AgentMemory_Memory implements INode { label: string @@ -32,7 +33,7 @@ class AgentMemory_Memory implements INode { label: 'Connect Credential', name: 'credential', type: 'credential', - credentialNames: ['PostgresApi'], + credentialNames: ['PostgresApi', 'MySQLApi'], optional: true } this.inputs = [ @@ -46,8 +47,12 @@ class AgentMemory_Memory implements INode { name: 'sqlite' }, { - label: 'Postgres', + label: 'PostgreSQL', name: 'postgres' + }, + { + label: 'MySQL', + name: 'mysql' } ], default: 'sqlite' @@ -66,7 +71,7 @@ class AgentMemory_Memory implements INode { label: 'Host', name: 'host', type: 'string', - description: 'If Postgres is selected, provide the host of the Postgres database', + description: 'If PostgresQL/MySQL is selected, provide the host of the database', additionalParams: true, optional: true }, @@ -74,7 +79,7 @@ class AgentMemory_Memory implements INode { label: 'Database', name: 'database', type: 'string', - description: 'If Postgres is selected, provide the name of the Postgres database', + description: 'If PostgresQL/MySQL is selected, provide the name of the database', additionalParams: true, optional: true }, @@ -82,8 +87,7 @@ class AgentMemory_Memory implements INode { label: 'Port', name: 'port', type: 'number', - description: 'If Postgres is selected, provide the port of the Postgres database', - placeholder: '5432', + description: 'If PostgresQL/MySQL is selected, provide the port of the database', additionalParams: true, optional: true }, @@ -158,6 +162,31 @@ class AgentMemory_Memory implements INode { } const recordManager = new PostgresSaver(args) return recordManager + } else if (databaseType === 'mysql') { + const credentialData = await getCredentialData(nodeData.credential ?? '', options) + const user = getCredentialParam('user', credentialData, nodeData) + const password = getCredentialParam('password', credentialData, nodeData) + const _port = (nodeData.inputs?.port as string) || '3306' + const port = parseInt(_port) + datasourceOptions = { + ...datasourceOptions, + host: nodeData.inputs?.host as string, + port, + database: nodeData.inputs?.database as string, + username: user, + user: user, + password: password, + charset: 'utf8mb4' + } + const args: SaverOptions = { + datasourceOptions, + threadId, + appDataSource, + databaseEntities, + chatflowid + } + const recordManager = new MySQLSaver(args) + return recordManager } return undefined diff --git a/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts new file mode 100644 index 00000000000..05bfcfc0b0d --- /dev/null +++ b/packages/components/nodes/memory/AgentMemory/mysqlSaver.ts @@ -0,0 +1,245 @@ +import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph' +import { RunnableConfig } from '@langchain/core/runnables' +import { BaseMessage } from '@langchain/core/messages' +import { DataSource, QueryRunner } from 'typeorm' +import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface' +import { IMessage, MemoryMethods } from '../../../src/Interface' +import { mapChatMessageToBaseMessage } from '../../../src/utils' + +export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods { + protected isSetup: boolean + + datasource: DataSource + + queryRunner: QueryRunner + + config: SaverOptions + + threadId: string + + tableName = 'checkpoints' + + constructor(config: SaverOptions, serde?: SerializerProtocol) { + super(serde) + this.config = config + const { datasourceOptions, threadId } = config + this.threadId = threadId + this.datasource = new DataSource(datasourceOptions) + } + + private async setup(): Promise { + if (this.isSetup) { + return + } + + try { + const appDataSource = await this.datasource.initialize() + + this.queryRunner = appDataSource.createQueryRunner() + await this.queryRunner.manager.query(` +CREATE TABLE IF NOT EXISTS ${this.tableName} ( + thread_id VARCHAR(255) NOT NULL, + checkpoint_id VARCHAR(255) NOT NULL, + parent_id VARCHAR(255), + checkpoint BLOB, + metadata BLOB, + PRIMARY KEY (thread_id, checkpoint_id) +);`) + } catch (error) { + console.error(`Error creating ${this.tableName} table`, error) + throw new Error(`Error creating ${this.tableName} table`) + } + + this.isSetup = true + } + + async getTuple(config: RunnableConfig): Promise { + await this.setup() + const thread_id = config.configurable?.thread_id || this.threadId + const checkpoint_id = config.configurable?.checkpoint_id + + if (checkpoint_id) { + try { + const keys = [thread_id, checkpoint_id] + const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?` + + const rows = await this.queryRunner.manager.query(sql, keys) + + if (rows && rows.length > 0) { + return { + config, + checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id, + checkpoint_id: rows[0].parent_id + } + } + : undefined + } + } + } catch (error) { + console.error(`Error retrieving ${this.tableName}`, error) + throw new Error(`Error retrieving ${this.tableName}`) + } + } else { + const keys = [thread_id] + const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1` + + const rows = await this.queryRunner.manager.query(sql, keys) + + if (rows && rows.length > 0) { + return { + config: { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].checkpoint_id + } + }, + checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata, + parentConfig: rows[0].parent_id + ? { + configurable: { + thread_id: rows[0].thread_id, + checkpoint_id: rows[0].parent_id + } + } + : undefined + } + } + } + return undefined + } + + async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator { + await this.setup() + const thread_id = config.configurable?.thread_id || this.threadId + let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${ + before ? 'AND checkpoint_id < ?' : '' + } ORDER BY checkpoint_id DESC` + if (limit) { + sql += ` LIMIT ${limit}` + } + const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean) + + try { + const rows = await this.queryRunner.manager.query(sql, args) + + if (rows && rows.length > 0) { + for (const row of rows) { + yield { + config: { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.checkpoint_id + } + }, + checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint, + metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata, + parentConfig: row.parent_id + ? { + configurable: { + thread_id: row.thread_id, + checkpoint_id: row.parent_id + } + } + : undefined + } + } + } + } catch (error) { + console.error(`Error listing ${this.tableName}`, error) + throw new Error(`Error listing ${this.tableName}`) + } + } + + async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise { + await this.setup() + if (!config.configurable?.checkpoint_id) return {} + try { + const row = [ + config.configurable?.thread_id || this.threadId, + checkpoint.id, + config.configurable?.checkpoint_id, + Buffer.from(this.serde.stringify(checkpoint)), // Encode to binary + Buffer.from(this.serde.stringify(metadata)) // Encode to binary + ] + + const query = `INSERT INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata) + VALUES (?, ?, ?, ?, ?) + ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)` + + await this.queryRunner.manager.query(query, row) + } catch (error) { + console.error('Error saving checkpoint', error) + throw new Error('Error saving checkpoint') + } + + return { + configurable: { + thread_id: config.configurable?.thread_id || this.threadId, + checkpoint_id: checkpoint.id + } + } + } + + async delete(threadId: string): Promise { + if (!threadId) { + return + } + await this.setup() + const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;` + + try { + await this.queryRunner.manager.query(query, [threadId]) + } catch (error) { + console.error(`Error deleting thread_id ${threadId}`, error) + } + } + + async getChatMessages( + overrideSessionId = '', + returnBaseMessages = false, + prependMessages?: IMessage[] + ): Promise { + if (!overrideSessionId) return [] + + const chatMessage = await this.config.appDataSource.getRepository(this.config.databaseEntities['ChatMessage']).find({ + where: { + sessionId: overrideSessionId, + chatflowid: this.config.chatflowid + }, + order: { + createdDate: 'ASC' + } + }) + + if (prependMessages?.length) { + chatMessage.unshift(...prependMessages) + } + + if (returnBaseMessages) { + return await mapChatMessageToBaseMessage(chatMessage) + } + + let returnIMessages: IMessage[] = [] + for (const m of chatMessage) { + returnIMessages.push({ + message: m.content as string, + type: m.role + }) + } + return returnIMessages + } + + async addChatMessages(): Promise { + // Empty as it's not being used + } + + async clearChatMessages(overrideSessionId = ''): Promise { + await this.delete(overrideSessionId) + } +}