Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Postgres agent memory #3495

Merged
merged 2 commits into from
Nov 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 94 additions & 3 deletions packages/components/nodes/memory/AgentMemory/AgentMemory.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import path from 'path'
import { getBaseClasses, getUserHome } from '../../../src/utils'
import { getBaseClasses, getCredentialData, getCredentialParam, getUserHome } from '../../../src/utils'
import { SaverOptions } from './interface'
import { ICommonObject, IDatabaseEntity, INode, INodeData, INodeParams } from '../../../src/Interface'
import { SqliteSaver } from './sqliteSaver'
import { DataSource } from 'typeorm'
import { PostgresSaver } from './pgSaver'
import { MySQLSaver } from './mysqlSaver'

class AgentMemory_Memory implements INode {
label: string
Expand All @@ -16,16 +18,24 @@ class AgentMemory_Memory implements INode {
badge: string
baseClasses: string[]
inputs: INodeParams[]
credential: INodeParams

constructor() {
this.label = 'Agent Memory'
this.name = 'agentMemory'
this.version = 1.0
this.version = 2.0
this.type = 'AgentMemory'
this.icon = 'agentmemory.svg'
this.category = 'Memory'
this.description = 'Memory for agentflow to remember the state of the conversation'
this.baseClasses = [this.type, ...getBaseClasses(SqliteSaver)]
this.credential = {
label: 'Connect Credential',
name: 'credential',
type: 'credential',
credentialNames: ['PostgresApi', 'MySQLApi'],
optional: true
}
this.inputs = [
{
label: 'Database',
Expand All @@ -35,6 +45,14 @@ class AgentMemory_Memory implements INode {
{
label: 'SQLite',
name: 'sqlite'
},
{
label: 'PostgreSQL',
name: 'postgres'
},
{
label: 'MySQL',
name: 'mysql'
}
],
default: 'sqlite'
Expand All @@ -49,6 +67,30 @@ class AgentMemory_Memory implements INode {
additionalParams: true,
optional: true
},
{
label: 'Host',
name: 'host',
type: 'string',
description: 'If PostgresQL/MySQL is selected, provide the host of the database',
additionalParams: true,
optional: true
},
{
label: 'Database',
name: 'database',
type: 'string',
description: 'If PostgresQL/MySQL is selected, provide the name of the database',
additionalParams: true,
optional: true
},
{
label: 'Port',
name: 'port',
type: 'number',
description: 'If PostgresQL/MySQL is selected, provide the port of the database',
additionalParams: true,
optional: true
},
{
label: 'Additional Connection Configuration',
name: 'additionalConfig',
Expand Down Expand Up @@ -78,7 +120,7 @@ class AgentMemory_Memory implements INode {

const threadId = options.sessionId || options.chatId

const datasourceOptions: ICommonObject = {
let datasourceOptions: ICommonObject = {
...additionalConfiguration,
type: databaseType
}
Expand All @@ -96,6 +138,55 @@ class AgentMemory_Memory implements INode {
}
const recordManager = new SqliteSaver(args)
return recordManager
} else if (databaseType === 'postgres') {
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const user = getCredentialParam('user', credentialData, nodeData)
const password = getCredentialParam('password', credentialData, nodeData)
const _port = (nodeData.inputs?.port as string) || '5432'
const port = parseInt(_port)
datasourceOptions = {
...datasourceOptions,
host: nodeData.inputs?.host as string,
port,
database: nodeData.inputs?.database as string,
username: user,
user: user,
password: password
}
const args: SaverOptions = {
datasourceOptions,
threadId,
appDataSource,
databaseEntities,
chatflowid
}
const recordManager = new PostgresSaver(args)
return recordManager
} else if (databaseType === 'mysql') {
const credentialData = await getCredentialData(nodeData.credential ?? '', options)
const user = getCredentialParam('user', credentialData, nodeData)
const password = getCredentialParam('password', credentialData, nodeData)
const _port = (nodeData.inputs?.port as string) || '3306'
const port = parseInt(_port)
datasourceOptions = {
...datasourceOptions,
host: nodeData.inputs?.host as string,
port,
database: nodeData.inputs?.database as string,
username: user,
user: user,
password: password,
charset: 'utf8mb4'
}
const args: SaverOptions = {
datasourceOptions,
threadId,
appDataSource,
databaseEntities,
chatflowid
}
const recordManager = new MySQLSaver(args)
return recordManager
}

return undefined
Expand Down
245 changes: 245 additions & 0 deletions packages/components/nodes/memory/AgentMemory/mysqlSaver.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
import { BaseCheckpointSaver, Checkpoint, CheckpointMetadata } from '@langchain/langgraph'
import { RunnableConfig } from '@langchain/core/runnables'
import { BaseMessage } from '@langchain/core/messages'
import { DataSource, QueryRunner } from 'typeorm'
import { CheckpointTuple, SaverOptions, SerializerProtocol } from './interface'
import { IMessage, MemoryMethods } from '../../../src/Interface'
import { mapChatMessageToBaseMessage } from '../../../src/utils'

export class MySQLSaver extends BaseCheckpointSaver implements MemoryMethods {
protected isSetup: boolean

datasource: DataSource

queryRunner: QueryRunner

config: SaverOptions

threadId: string

tableName = 'checkpoints'

constructor(config: SaverOptions, serde?: SerializerProtocol<Checkpoint>) {
super(serde)
this.config = config
const { datasourceOptions, threadId } = config
this.threadId = threadId
this.datasource = new DataSource(datasourceOptions)
}

private async setup(): Promise<void> {
if (this.isSetup) {
return
}

try {
const appDataSource = await this.datasource.initialize()

this.queryRunner = appDataSource.createQueryRunner()
await this.queryRunner.manager.query(`
CREATE TABLE IF NOT EXISTS ${this.tableName} (
thread_id VARCHAR(255) NOT NULL,
checkpoint_id VARCHAR(255) NOT NULL,
parent_id VARCHAR(255),
checkpoint BLOB,
metadata BLOB,
PRIMARY KEY (thread_id, checkpoint_id)
);`)
} catch (error) {
console.error(`Error creating ${this.tableName} table`, error)
throw new Error(`Error creating ${this.tableName} table`)
}

this.isSetup = true
}

async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
await this.setup()
const thread_id = config.configurable?.thread_id || this.threadId
const checkpoint_id = config.configurable?.checkpoint_id

if (checkpoint_id) {
try {
const keys = [thread_id, checkpoint_id]
const sql = `SELECT checkpoint, parent_id, metadata FROM ${this.tableName} WHERE thread_id = ? AND checkpoint_id = ?`

const rows = await this.queryRunner.manager.query(sql, keys)

if (rows && rows.length > 0) {
return {
config,
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id,
checkpoint_id: rows[0].parent_id
}
}
: undefined
}
}
} catch (error) {
console.error(`Error retrieving ${this.tableName}`, error)
throw new Error(`Error retrieving ${this.tableName}`)
}
} else {
const keys = [thread_id]
const sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ORDER BY checkpoint_id DESC LIMIT 1`

const rows = await this.queryRunner.manager.query(sql, keys)

if (rows && rows.length > 0) {
return {
config: {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].checkpoint_id
}
},
checkpoint: (await this.serde.parse(rows[0].checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(rows[0].metadata.toString())) as CheckpointMetadata,
parentConfig: rows[0].parent_id
? {
configurable: {
thread_id: rows[0].thread_id,
checkpoint_id: rows[0].parent_id
}
}
: undefined
}
}
}
return undefined
}

async *list(config: RunnableConfig, limit?: number, before?: RunnableConfig): AsyncGenerator<CheckpointTuple> {
await this.setup()
const thread_id = config.configurable?.thread_id || this.threadId
let sql = `SELECT thread_id, checkpoint_id, parent_id, checkpoint, metadata FROM ${this.tableName} WHERE thread_id = ? ${
before ? 'AND checkpoint_id < ?' : ''
} ORDER BY checkpoint_id DESC`
if (limit) {
sql += ` LIMIT ${limit}`
}
const args = [thread_id, before?.configurable?.checkpoint_id].filter(Boolean)

try {
const rows = await this.queryRunner.manager.query(sql, args)

if (rows && rows.length > 0) {
for (const row of rows) {
yield {
config: {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.checkpoint_id
}
},
checkpoint: (await this.serde.parse(row.checkpoint.toString())) as Checkpoint,
metadata: (await this.serde.parse(row.metadata.toString())) as CheckpointMetadata,
parentConfig: row.parent_id
? {
configurable: {
thread_id: row.thread_id,
checkpoint_id: row.parent_id
}
}
: undefined
}
}
}
} catch (error) {
console.error(`Error listing ${this.tableName}`, error)
throw new Error(`Error listing ${this.tableName}`)
}
}

async put(config: RunnableConfig, checkpoint: Checkpoint, metadata: CheckpointMetadata): Promise<RunnableConfig> {
await this.setup()
if (!config.configurable?.checkpoint_id) return {}
try {
const row = [
config.configurable?.thread_id || this.threadId,
checkpoint.id,
config.configurable?.checkpoint_id,
Buffer.from(this.serde.stringify(checkpoint)), // Encode to binary
Buffer.from(this.serde.stringify(metadata)) // Encode to binary
]

const query = `INSERT INTO ${this.tableName} (thread_id, checkpoint_id, parent_id, checkpoint, metadata)
VALUES (?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE checkpoint = VALUES(checkpoint), metadata = VALUES(metadata)`

await this.queryRunner.manager.query(query, row)
} catch (error) {
console.error('Error saving checkpoint', error)
throw new Error('Error saving checkpoint')
}

return {
configurable: {
thread_id: config.configurable?.thread_id || this.threadId,
checkpoint_id: checkpoint.id
}
}
}

async delete(threadId: string): Promise<void> {
if (!threadId) {
return
}
await this.setup()
const query = `DELETE FROM ${this.tableName} WHERE thread_id = ?;`

try {
await this.queryRunner.manager.query(query, [threadId])
} catch (error) {
console.error(`Error deleting thread_id ${threadId}`, error)
}
}

async getChatMessages(
overrideSessionId = '',
returnBaseMessages = false,
prependMessages?: IMessage[]
): Promise<IMessage[] | BaseMessage[]> {
if (!overrideSessionId) return []

const chatMessage = await this.config.appDataSource.getRepository(this.config.databaseEntities['ChatMessage']).find({
where: {
sessionId: overrideSessionId,
chatflowid: this.config.chatflowid
},
order: {
createdDate: 'ASC'
}
})

if (prependMessages?.length) {
chatMessage.unshift(...prependMessages)
}

if (returnBaseMessages) {
return await mapChatMessageToBaseMessage(chatMessage)
}

let returnIMessages: IMessage[] = []
for (const m of chatMessage) {
returnIMessages.push({
message: m.content as string,
type: m.role
})
}
return returnIMessages
}

async addChatMessages(): Promise<void> {
// Empty as it's not being used
}

async clearChatMessages(overrideSessionId = ''): Promise<void> {
await this.delete(overrideSessionId)
}
}
Loading
Loading