From 01eb5a1da95abd1c20a3a9b89c80533141c398c9 Mon Sep 17 00:00:00 2001 From: v1xingyue Date: Mon, 24 Feb 2025 10:02:26 +0800 Subject: [PATCH 01/13] add agent server opitons with middleware settings --- packages/agent/src/server/index.ts | 1705 ++++++++++++++-------------- 1 file changed, 852 insertions(+), 853 deletions(-) diff --git a/packages/agent/src/server/index.ts b/packages/agent/src/server/index.ts index 353f6e32d11..7d9667e00a4 100644 --- a/packages/agent/src/server/index.ts +++ b/packages/agent/src/server/index.ts @@ -1,16 +1,16 @@ import { - ChannelType, - composeContext, - generateMessageResponse, - generateObject, - logger, - ModelClass, - stringToUuid, - type Character, - type Content, - type IAgentRuntime, - type Media, - type Memory + ChannelType, + composeContext, + generateMessageResponse, + generateObject, + logger, + ModelClass, + stringToUuid, + type Character, + type Content, + type IAgentRuntime, + type Media, + type Memory, } from "@elizaos/core"; import bodyParser from "body-parser"; import cors from "cors"; @@ -20,888 +20,887 @@ import * as fs from "node:fs"; import * as path from "node:path"; import { z } from "zod"; import { createApiRouter } from "./api.ts"; -import { hyperfiHandlerTemplate, messageHandlerTemplate, upload } from "./helper.ts"; +import { + hyperfiHandlerTemplate, + messageHandlerTemplate, + upload, +} from "./helper.ts"; import replyAction from "./reply.ts"; +export interface ServerMiddleware { + ( + req: express.Request, + res: express.Response, + next: express.NextFunction + ): void; +} + +export interface ServerOptions { + middlewares?: ServerMiddleware[]; +} + export class CharacterServer { - public app: express.Application; - private agents: Map; // container management - private server: any; // Store server instance - public startAgent: (character: Character) => Promise; // Store startAgent function - public loadCharacterTryPath: (characterPath: string) => Promise; // Store loadCharacterTryPath function - public jsonToCharacter: (filePath: string, character: string | never) => Promise; // Store jsonToCharacter function - - constructor() { - logger.log("DirectClient constructor"); - this.app = express(); - this.app.use(cors()); - this.agents = new Map(); - - this.app.use(bodyParser.json()); - this.app.use(bodyParser.urlencoded({ extended: true })); - - // Serve both uploads and generated images - this.app.use( - "/media/uploads", - express.static(path.join(process.cwd(), "/data/uploads")) - ); - this.app.use( - "/media/generated", - express.static(path.join(process.cwd(), "/generatedImages")) - ); + public app: express.Application; + private agents: Map; // container management + private server: any; // Store server instance + public startAgent: (character: Character) => Promise; // Store startAgent function + public loadCharacterTryPath: (characterPath: string) => Promise; // Store loadCharacterTryPath function + public jsonToCharacter: ( + filePath: string, + character: string | never + ) => Promise; // Store jsonToCharacter function + + constructor(options: ServerOptions) { + logger.log("DirectClient constructor"); + this.app = express(); + this.app.use(cors()); + this.agents = new Map(); + + this.app.use(bodyParser.json()); + this.app.use(bodyParser.urlencoded({ extended: true })); + + if (options.middlewares) { + for (const middleware of options.middlewares) { + this.app.use(middleware); + } + } - const apiRouter = createApiRouter(this.agents, this); - this.app.use(apiRouter); + // Serve both uploads and generated images + 
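// A minimal usage sketch for the new middlewares option (the requestLogger
// name is illustrative, not part of this diff; it assumes only the
// ServerMiddleware and ServerOptions shapes declared above):
//
//   const requestLogger: ServerMiddleware = (req, _res, next) => {
//     console.log(`${req.method} ${req.path}`);
//     next();
//   };
//   const server = new CharacterServer({ middlewares: [requestLogger] });
//
// The same handler could also be attached after construction via the
// registerMiddleware() method added later in this patch.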
this.app.use( + "/media/uploads", + express.static(path.join(process.cwd(), "/data/uploads")) + ); + this.app.use( + "/media/generated", + express.static(path.join(process.cwd(), "/generatedImages")) + ); + + const apiRouter = createApiRouter(this.agents, this); + this.app.use(apiRouter); + + // Define an interface that extends the Express Request interface + interface CustomRequest extends ExpressRequest { + file?: Express.Multer.File; + } - // Define an interface that extends the Express Request interface - interface CustomRequest extends ExpressRequest { - file?: Express.Multer.File; + // Update the route handler to use CustomRequest instead of express.Request + this.app.post( + "/:agentId/whisper", + upload.single("file"), + async (req: CustomRequest, res: express.Response) => { + const audioFile = req.file; // Access the uploaded file using req.file + const agentId = req.params.agentId; + + if (!audioFile) { + res.status(400).send("No audio file provided"); + return; } - // Update the route handler to use CustomRequest instead of express.Request - this.app.post( - "/:agentId/whisper", - upload.single("file"), - async (req: CustomRequest, res: express.Response) => { - const audioFile = req.file; // Access the uploaded file using req.file - const agentId = req.params.agentId; - - if (!audioFile) { - res.status(400).send("No audio file provided"); - return; - } - - let runtime = this.agents.get(agentId); - - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => - a.character.name.toLowerCase() === - agentId.toLowerCase() - ); - } - - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - const audioBuffer = fs.readFileSync(audioFile.path); - const transcription = await runtime.useModel(ModelClass.TRANSCRIPTION, audioBuffer); - - res.json({text: transcription}); - } - ); + let runtime = this.agents.get(agentId); - this.app.post( - "/:agentId/message", - upload.single("file"), - async (req: express.Request, res: express.Response) => { - const agentId = req.params.agentId; - const roomId = stringToUuid( - req.body.roomId ?? `default-room-${agentId}` - ); - const userId = stringToUuid(req.body.userId ?? 
"user"); - - let runtime = this.agents.get(agentId); - - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => - a.character.name.toLowerCase() === - agentId.toLowerCase() - ); - } - - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - await runtime.ensureConnection({ - userId, - roomId, - userName: req.body.userName, - userScreenName: req.body.name, - source: "direct", - type: ChannelType.API, - }); - - const text = req.body.text; - // if empty text, directly return - if (!text) { - res.json([]); - return; - } - - const messageId = stringToUuid(Date.now().toString()); - - const attachments: Media[] = []; - if (req.file) { - const filePath = path.join( - process.cwd(), - "data", - "uploads", - req.file.filename - ); - attachments.push({ - id: Date.now().toString(), - url: filePath, - title: req.file.originalname, - source: "direct", - description: `Uploaded file: ${req.file.originalname}`, - text: "", - contentType: req.file.mimetype, - }); - } - - const content: Content = { - text, - attachments, - source: "direct", - inReplyTo: undefined, - }; - - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const memory: Memory = { - id: stringToUuid(`${messageId}-${userId}`), - ...userMessage, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - await runtime.messageManager.addEmbeddingToMemory(memory); - await runtime.messageManager.createMemory(memory); - - let state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - const context = composeContext({ - state, - template: messageHandlerTemplate, - }); - - const response = await generateMessageResponse({ - runtime: runtime, - context, - modelClass: ModelClass.TEXT_LARGE, - }); - - if (!response) { - res.status(500).send( - "No response from generateMessageResponse" - ); - return; - } - - // save response to memory - const responseMessage: Memory = { - id: stringToUuid(`${messageId}-${runtime.agentId}`), - ...userMessage, - userId: runtime.agentId, - content: response, - createdAt: Date.now(), - }; - - await runtime.messageManager.createMemory(responseMessage); - - state = await runtime.updateRecentMessageState(state); - - const replyHandler = async (message: Content) => { - res.json([message]); - return [memory]; - } + // if runtime is null, look for runtime with the same name + if (!runtime) { + runtime = Array.from(this.agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } - await runtime.processActions( - memory, - [responseMessage], - state, - replyHandler - ); + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } - await runtime.evaluate(memory, state); - } + const audioBuffer = fs.readFileSync(audioFile.path); + const transcription = await runtime.useModel( + ModelClass.TRANSCRIPTION, + audioBuffer ); - this.app.post( - "/agents/:agentIdOrName/hyperfy/v1", - async (req: express.Request, res: express.Response) => { - // get runtime - const agentId = req.params.agentIdOrName; - let runtime = this.agents.get(agentId); - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => - a.character.name.toLowerCase() === - agentId.toLowerCase() - ); - } - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - // can we be in more than one hyperfy world at once - // 
but you may want the same context is multiple worlds - // this is more like an instanceId - const roomId = stringToUuid(req.body.roomId ?? "hyperfy"); - - const body = req.body; - - // hyperfy specific parameters - let nearby = []; - let availableEmotes = []; - - if (body.nearby) { - nearby = body.nearby; - } - if (body.messages) { - // loop on the messages and record the memories - // might want to do this in parallel - for (const msg of body.messages) { - const parts = msg.split(/:\s*/); - const mUserId = stringToUuid(parts[0]); - await runtime.ensureConnection({ - userId: mUserId, - roomId, // where - userName: parts[0], // username - userScreenName: parts[0], // userScreeName? - source: "hyperfy", - type: ChannelType.WORLD, - }); - const content: Content = { - text: parts[1] || "", - attachments: [], - source: "hyperfy", - inReplyTo: undefined, - }; - const memory: Memory = { - id: stringToUuid(msg), - agentId: runtime.agentId, - userId: mUserId, - roomId, - content, - }; - await runtime.messageManager.createMemory(memory); - } - } - if (body.availableEmotes) { - availableEmotes = body.availableEmotes; - } - - const content: Content = { - // we need to compose who's near and what emotes are available - text: JSON.stringify(req.body), - attachments: [], - source: "hyperfy", - inReplyTo: undefined, - }; - - const userId = stringToUuid("hyperfy"); - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - let template = hyperfiHandlerTemplate; - template = template.replace( - "{{emotes}}", - availableEmotes.join("|") - ); - template = template.replace("{{nearby}}", nearby.join("|")); - const context = composeContext({ - state, - template, - }); - - function createHyperfiOutSchema( - nearby: string[], - availableEmotes: string[] - ) { - const lookAtSchema = - nearby.length > 1 - ? z - .union( - nearby.map((item) => z.literal(item)) as [ - z.ZodLiteral, - z.ZodLiteral, - ...z.ZodLiteral[], - ] - ) - .nullable() - : nearby.length === 1 - ? z.literal(nearby[0]).nullable() - : z.null(); // Fallback for empty array - - const emoteSchema = - availableEmotes.length > 1 - ? z - .union( - availableEmotes.map((item) => - z.literal(item) - ) as [ - z.ZodLiteral, - z.ZodLiteral, - ...z.ZodLiteral[], - ] - ) - .nullable() - : availableEmotes.length === 1 - ? z.literal(availableEmotes[0]).nullable() - : z.null(); // Fallback for empty array - - return z.object({ - lookAt: lookAtSchema, - emote: emoteSchema, - say: z.string().nullable(), - actions: z.array(z.string()).nullable(), - }); - } - - // Define the schema for the expected output - const hyperfiOutSchema = createHyperfiOutSchema( - nearby, - availableEmotes - ); - - // Call LLM - const response = await generateObject({ - runtime, - context, - modelClass: ModelClass.TEXT_SMALL, - schema: hyperfiOutSchema, - }); - - if (!response) { - res.status(500).send( - "No response from generateMessageResponse" - ); - return; - } - - let hfOut; - try { - hfOut = hyperfiOutSchema.parse(response.object); - } catch { - logger.error( - "cant serialize response", - response.object - ); - res.status(500).send("Error in LLM response, try again"); - return; - } - - // do this in the background - new Promise((resolve) => { - const contentObj: Content = { - text: hfOut.say, - }; - - if (hfOut.lookAt !== null || hfOut.emote !== null) { - contentObj.text += ". 
Then I "; - if (hfOut.lookAt !== null) { - contentObj.text += `looked at ${hfOut.lookAt}`; - if (hfOut.emote !== null) { - contentObj.text += " and "; - } - } - if (hfOut.emote !== null) { - contentObj.text = `emoted ${hfOut.emote}`; - } - } - - if (hfOut.actions !== null) { - // content can only do one action - contentObj.action = hfOut.actions[0]; - } - - // save response to memory - const responseMessage = { - ...userMessage, - userId: runtime.agentId, - content: contentObj, - }; - - runtime.messageManager - .createMemory(responseMessage) - .then(() => { - const messageId = stringToUuid( - Date.now().toString() - ); - const memory: Memory = { - id: messageId, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - // run evaluators (generally can be done in parallel with processActions) - // can an evaluator modify memory? it could but currently doesn't - runtime.evaluate(memory, state).then(() => { - // only need to call if responseMessage.content.action is set - if (contentObj.action) { - // pass memory (query) to any actions to call - runtime.processActions( - memory, - [responseMessage], - state, - async (_newMessages) => { - // FIXME: this is supposed override what the LLM said/decided - // but the promise doesn't make this possible - //message = newMessages; - return [memory]; - } - ); // 0.674s - } - resolve(true); - }); - }); - }); - res.json({ response: hfOut }); - } + res.json({ text: transcription }); + } + ); + + this.app.post( + "/:agentId/message", + upload.single("file"), + async (req: express.Request, res: express.Response) => { + const agentId = req.params.agentId; + const roomId = stringToUuid( + req.body.roomId ?? `default-room-${agentId}` ); + const userId = stringToUuid(req.body.userId ?? "user"); - this.app.post( - "/:agentId/image", - async (req: express.Request, res: express.Response) => { - const agentId = req.params.agentId; - const agent = this.agents.get(agentId); - if (!agent) { - res.status(404).send("Agent not found"); - return; - } - const images = await agent.useModel(ModelClass.IMAGE, { ...req.body }); - const imagesRes: { image: string; caption: string }[] = []; - if (images.data && images.data.length > 0) { - for (let i = 0; i < images.data.length; i++) { - const caption = await agent.useModel(ModelClass.IMAGE_DESCRIPTION, images.data[i]); - imagesRes.push({ - image: images.data[i], - caption: caption.title, - }); - } - } - res.json({ images: imagesRes }); - } - ); + let runtime = this.agents.get(agentId); - this.app.post( - "/fine-tune", - async (req: express.Request, res: express.Response) => { - try { - const response = await fetch( - "https://api.bageldb.ai/api/v1/asset", - { - method: "POST", - headers: { - "Content-Type": "application/json", - "X-API-KEY": `${process.env.BAGEL_API_KEY}`, - }, - body: JSON.stringify(req.body), - } - ); - - const data = await response.json(); - res.json(data); - } catch (error) { - res.status(500).json({ - error: "Please create an account at bakery.bagel.net and get an API key. 
Then set the BAGEL_API_KEY environment variable.", - details: error.message, - }); - } - } - ); + // if runtime is null, look for runtime with the same name + if (!runtime) { + runtime = Array.from(this.agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } - this.app.get( - "/fine-tune/:assetId", - async (req: express.Request, res: express.Response) => { - const assetId = req.params.assetId; - const downloadDir = path.join( - process.cwd(), - "downloads", - assetId - ); - - logger.log("Download directory:", downloadDir); - - try { - logger.log("Creating directory..."); - await fs.promises.mkdir(downloadDir, { recursive: true }); - - logger.log("Fetching file..."); - const fileResponse = await fetch( - `https://api.bageldb.ai/api/v1/asset/${assetId}/download`, - { - headers: { - "X-API-KEY": `${process.env.BAGEL_API_KEY}`, - }, - } - ); - - if (!fileResponse.ok) { - throw new Error( - `API responded with status ${fileResponse.status}: ${await fileResponse.text()}` - ); - } - - logger.log("Response headers:", fileResponse.headers); - - const fileName = - fileResponse.headers - .get("content-disposition") - ?.split("filename=")[1] - ?.replace(/"/g, /* " */ "") || "default_name.txt"; - - logger.log("Saving as:", fileName); - - const arrayBuffer = await fileResponse.arrayBuffer(); - const buffer = Buffer.from(arrayBuffer); - - const filePath = path.join(downloadDir, fileName); - logger.log("Full file path:", filePath); - - await fs.promises.writeFile(filePath, buffer); - - // Verify file was written - const stats = await fs.promises.stat(filePath); - logger.log( - "File written successfully. Size:", - stats.size, - "bytes" - ); - - res.json({ - success: true, - message: "Single file downloaded successfully", - downloadPath: downloadDir, - fileCount: 1, - fileName: fileName, - fileSize: stats.size, - }); - } catch (error) { - logger.error("Detailed error:", error); - res.status(500).json({ - error: "Failed to download files from BagelDB", - details: error.message, - stack: error.stack, - }); - } - } + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } + + await runtime.ensureConnection({ + userId, + roomId, + userName: req.body.userName, + userScreenName: req.body.name, + source: "direct", + type: ChannelType.API, + }); + + const text = req.body.text; + // if empty text, directly return + if (!text) { + res.json([]); + return; + } + + const messageId = stringToUuid(Date.now().toString()); + + const attachments: Media[] = []; + if (req.file) { + const filePath = path.join( + process.cwd(), + "data", + "uploads", + req.file.filename + ); + attachments.push({ + id: Date.now().toString(), + url: filePath, + title: req.file.originalname, + source: "direct", + description: `Uploaded file: ${req.file.originalname}`, + text: "", + contentType: req.file.mimetype, + }); + } + + const content: Content = { + text, + attachments, + source: "direct", + inReplyTo: undefined, + }; + + const userMessage = { + content, + userId, + roomId, + agentId: runtime.agentId, + }; + + const memory: Memory = { + id: stringToUuid(`${messageId}-${userId}`), + ...userMessage, + agentId: runtime.agentId, + userId, + roomId, + content, + createdAt: Date.now(), + }; + + await runtime.messageManager.addEmbeddingToMemory(memory); + await runtime.messageManager.createMemory(memory); + + let state = await runtime.composeState(userMessage, { + agentName: runtime.character.name, + }); + + const context = composeContext({ + state, + template: messageHandlerTemplate, + }); 
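// At this point the user message has been persisted with its embedding,
// room state has been composed, and the handler template has been rendered
// into a prompt context; the call below hands that context to the large
// text model and expects a structured Content reply.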
+ + const response = await generateMessageResponse({ + runtime: runtime, + context, + modelClass: ModelClass.TEXT_LARGE, + }); + + if (!response) { + res.status(500).send("No response from generateMessageResponse"); + return; + } + + // save response to memory + const responseMessage: Memory = { + id: stringToUuid(`${messageId}-${runtime.agentId}`), + ...userMessage, + userId: runtime.agentId, + content: response, + createdAt: Date.now(), + }; + + await runtime.messageManager.createMemory(responseMessage); + + state = await runtime.updateRecentMessageState(state); + + const replyHandler = async (message: Content) => { + res.json([message]); + return [memory]; + }; + + await runtime.processActions( + memory, + [responseMessage], + state, + replyHandler ); - this.app.post("/:agentId/speak", async (req, res) => { - const agentId = req.params.agentId; - const roomId = stringToUuid( - req.body.roomId ?? `default-room-${agentId}` - ); - const userId = stringToUuid(req.body.userId ?? "user"); - const text = req.body.text; + await runtime.evaluate(memory, state); + } + ); + + this.app.post( + "/agents/:agentIdOrName/hyperfy/v1", + async (req: express.Request, res: express.Response) => { + // get runtime + const agentId = req.params.agentIdOrName; + let runtime = this.agents.get(agentId); + // if runtime is null, look for runtime with the same name + if (!runtime) { + runtime = Array.from(this.agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } - if (!text) { - res.status(400).send("No text provided"); - return; - } + // can we be in more than one hyperfy world at once + // but you may want the same context is multiple worlds + // this is more like an instanceId + const roomId = stringToUuid(req.body.roomId ?? "hyperfy"); - let runtime = this.agents.get(agentId); + const body = req.body; - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => - a.character.name.toLowerCase() === agentId.toLowerCase() - ); - } + // hyperfy specific parameters + let nearby = []; + let availableEmotes = []; - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } + if (body.nearby) { + nearby = body.nearby; + } + if (body.messages) { + // loop on the messages and record the memories + // might want to do this in parallel + for (const msg of body.messages) { + const parts = msg.split(/:\s*/); + const mUserId = stringToUuid(parts[0]); + await runtime.ensureConnection({ + userId: mUserId, + roomId, // where + userName: parts[0], // username + userScreenName: parts[0], // userScreeName? 
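// Incoming hyperfy chat lines arrive as "username: text" strings; the
// split above reuses the username for the userId, userName, and screen
// name fields before the message is stored as a memory below.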
+ source: "hyperfy", + type: ChannelType.WORLD, + }); + const content: Content = { + text: parts[1] || "", + attachments: [], + source: "hyperfy", + inReplyTo: undefined, + }; + const memory: Memory = { + id: stringToUuid(msg), + agentId: runtime.agentId, + userId: mUserId, + roomId, + content, + }; + await runtime.messageManager.createMemory(memory); + } + } + if (body.availableEmotes) { + availableEmotes = body.availableEmotes; + } - try { - // Process message through agent (same as /message endpoint) - await runtime.ensureConnection({ - userId, - roomId, - userName: req.body.userName, - userScreenName: req.body.name, - source: "direct", - type: ChannelType.API, - }); - - const messageId = stringToUuid(Date.now().toString()); - - const content: Content = { - text, - attachments: [], - source: "direct", - inReplyTo: undefined, - }; - - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const memory: Memory = { - id: messageId, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - await runtime.messageManager.createMemory(memory); - - const state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - const context = composeContext({ - state, - template: messageHandlerTemplate, - }); - - const response = await generateMessageResponse({ - runtime: runtime, - context, - modelClass: ModelClass.TEXT_LARGE, - }); - - // save response to memory - const responseMessage = { - ...userMessage, - userId: runtime.agentId, - content: response, - }; - - await runtime.messageManager.createMemory(responseMessage); - - if (!response) { - res.status(500).send( - "No response from generateMessageResponse" - ); - return; - } - - await runtime.evaluate(memory, state); - - const _result = await runtime.processActions( - memory, - [responseMessage], - state, - async () => { - return [memory]; - } - ); - - // Get the text to convert to speech - const textToSpeak = response.text; - - const speechResponse = await runtime.useModel(ModelClass.TEXT_TO_SPEECH, textToSpeak); - - if (!speechResponse.ok) { - throw new Error( - `ElevenLabs API error: ${speechResponse.statusText}` - ); - } - - const audioBuffer = await speechResponse.arrayBuffer(); - - // Set appropriate headers for audio streaming - res.set({ - "Content-Type": "audio/mpeg", - "Transfer-Encoding": "chunked", - }); - - res.send(Buffer.from(audioBuffer)); - } catch (error) { - logger.error( - "Error processing message or generating speech:", - error - ); - res.status(500).json({ - error: "Error processing message or generating speech", - details: error.message, - }); - } + const content: Content = { + // we need to compose who's near and what emotes are available + text: JSON.stringify(req.body), + attachments: [], + source: "hyperfy", + inReplyTo: undefined, + }; + + const userId = stringToUuid("hyperfy"); + const userMessage = { + content, + userId, + roomId, + agentId: runtime.agentId, + }; + + const state = await runtime.composeState(userMessage, { + agentName: runtime.character.name, }); - this.app.post("/:agentId/tts", async (req, res) => { - const text = req.body.text; + let template = hyperfiHandlerTemplate; + template = template.replace("{{emotes}}", availableEmotes.join("|")); + template = template.replace("{{nearby}}", nearby.join("|")); + const context = composeContext({ + state, + template, + }); - if (!text) { - res.status(400).send("No text provided"); - return; - } + function createHyperfiOutSchema( + nearby: string[], + 
availableEmotes: string[] + ) { + const lookAtSchema = + nearby.length > 1 + ? z + .union( + nearby.map((item) => z.literal(item)) as [ + z.ZodLiteral, + z.ZodLiteral, + ...z.ZodLiteral[] + ] + ) + .nullable() + : nearby.length === 1 + ? z.literal(nearby[0]).nullable() + : z.null(); // Fallback for empty array + + const emoteSchema = + availableEmotes.length > 1 + ? z + .union( + availableEmotes.map((item) => z.literal(item)) as [ + z.ZodLiteral, + z.ZodLiteral, + ...z.ZodLiteral[] + ] + ) + .nullable() + : availableEmotes.length === 1 + ? z.literal(availableEmotes[0]).nullable() + : z.null(); // Fallback for empty array + + return z.object({ + lookAt: lookAtSchema, + emote: emoteSchema, + say: z.string().nullable(), + actions: z.array(z.string()).nullable(), + }); + } - try { - // Convert to speech using ElevenLabs - const elevenLabsApiUrl = `https://api.elevenlabs.io/v1/text-to-speech/${process.env.ELEVENLABS_VOICE_ID}`; - const apiKey = process.env.ELEVENLABS_XI_API_KEY; - - if (!apiKey) { - throw new Error("ELEVENLABS_XI_API_KEY not configured"); - } - - // TODO: Replace the process.env with settings from the character read from the database - - const speechResponse = await fetch(elevenLabsApiUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - "xi-api-key": apiKey, - }, - body: JSON.stringify({ - text, - model_id: - process.env.ELEVENLABS_MODEL_ID || - "eleven_multilingual_v2", - voice_settings: { - stability: Number.parseFloat( - process.env.ELEVENLABS_VOICE_STABILITY || "0.5" - ), - similarity_boost: Number.parseFloat( - process.env.ELEVENLABS_VOICE_SIMILARITY_BOOST || - "0.9" - ), - style: Number.parseFloat( - process.env.ELEVENLABS_VOICE_STYLE || "0.66" - ), - use_speaker_boost: - process.env - .ELEVENLABS_VOICE_USE_SPEAKER_BOOST === - "true", - }, - }), - }); - - if (!speechResponse.ok) { - throw new Error( - `ElevenLabs API error: ${speechResponse.statusText}` - ); - } - - const audioBuffer = await speechResponse.arrayBuffer(); - - res.set({ - "Content-Type": "audio/mpeg", - "Transfer-Encoding": "chunked", - }); - - res.send(Buffer.from(audioBuffer)); - } catch (error) { - logger.error( - "Error processing message or generating speech:", - error - ); - res.status(500).json({ - error: "Error processing message or generating speech", - details: error.message, - }); - } + // Define the schema for the expected output + const hyperfiOutSchema = createHyperfiOutSchema( + nearby, + availableEmotes + ); + + // Call LLM + const response = await generateObject({ + runtime, + context, + modelClass: ModelClass.TEXT_SMALL, + schema: hyperfiOutSchema, }); - } - // agent/src/index.ts:startAgent calls this - public registerAgent(runtime: IAgentRuntime) { - // register any plugin endpoints? - // but once and only once - this.agents.set(runtime.agentId, runtime); - // TODO: This is a hack to register the tee plugin. Remove this once we have a better way to do it. 
- const teePlugin = runtime.plugins.find(p => p.name === "phala-tee-plugin"); - if (teePlugin) { - for (const provider of teePlugin.providers) { - runtime.registerProvider(provider); + if (!response) { + res.status(500).send("No response from generateMessageResponse"); + return; + } + + let hfOut; + try { + hfOut = hyperfiOutSchema.parse(response.object); + } catch { + logger.error("cant serialize response", response.object); + res.status(500).send("Error in LLM response, try again"); + return; + } + + // do this in the background + new Promise((resolve) => { + const contentObj: Content = { + text: hfOut.say, + }; + + if (hfOut.lookAt !== null || hfOut.emote !== null) { + contentObj.text += ". Then I "; + if (hfOut.lookAt !== null) { + contentObj.text += `looked at ${hfOut.lookAt}`; + if (hfOut.emote !== null) { + contentObj.text += " and "; + } } - for (const action of teePlugin.actions) { - runtime.registerAction(action); + if (hfOut.emote !== null) { + contentObj.text = `emoted ${hfOut.emote}`; } + } + + if (hfOut.actions !== null) { + // content can only do one action + contentObj.action = hfOut.actions[0]; + } + + // save response to memory + const responseMessage = { + ...userMessage, + userId: runtime.agentId, + content: contentObj, + }; + + runtime.messageManager.createMemory(responseMessage).then(() => { + const messageId = stringToUuid(Date.now().toString()); + const memory: Memory = { + id: messageId, + agentId: runtime.agentId, + userId, + roomId, + content, + createdAt: Date.now(), + }; + + // run evaluators (generally can be done in parallel with processActions) + // can an evaluator modify memory? it could but currently doesn't + runtime.evaluate(memory, state).then(() => { + // only need to call if responseMessage.content.action is set + if (contentObj.action) { + // pass memory (query) to any actions to call + runtime.processActions( + memory, + [responseMessage], + state, + async (_newMessages) => { + // FIXME: this is supposed override what the LLM said/decided + // but the promise doesn't make this possible + //message = newMessages; + return [memory]; + } + ); // 0.674s + } + resolve(true); + }); + }); + }); + res.json({ response: hfOut }); + } + ); + + this.app.post( + "/:agentId/image", + async (req: express.Request, res: express.Response) => { + const agentId = req.params.agentId; + const agent = this.agents.get(agentId); + if (!agent) { + res.status(404).send("Agent not found"); + return; + } + const images = await agent.useModel(ModelClass.IMAGE, { ...req.body }); + const imagesRes: { image: string; caption: string }[] = []; + if (images.data && images.data.length > 0) { + for (let i = 0; i < images.data.length; i++) { + const caption = await agent.useModel( + ModelClass.IMAGE_DESCRIPTION, + images.data[i] + ); + imagesRes.push({ + image: images.data[i], + caption: caption.title, + }); + } } - runtime.registerAction(replyAction); - // for each route on each plugin, add it to the router - for (const route of runtime.routes) { - // if the path hasn't been added yet, add it - switch (route.type) { - case "GET": - this.app.get(route.path, (req: any, res: any) => route.handler(req, res)); - break; - case "POST": - this.app.post(route.path, (req: any, res: any) => route.handler(req, res)); - break; - case "PUT": - this.app.put(route.path, (req: any, res: any) => route.handler(req, res)); - break; - case "DELETE": - this.app.delete(route.path, (req: any, res: any) => route.handler(req, res)); - break; - default: - logger.error(`Unknown route type: ${route.type}`); + 
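// Each generated image is run back through the IMAGE_DESCRIPTION model so
// the response below can pair every image with a caption.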
res.json({ images: imagesRes }); + } + ); + + this.app.post( + "/fine-tune", + async (req: express.Request, res: express.Response) => { + try { + const response = await fetch("https://api.bageldb.ai/api/v1/asset", { + method: "POST", + headers: { + "Content-Type": "application/json", + "X-API-KEY": `${process.env.BAGEL_API_KEY}`, + }, + body: JSON.stringify(req.body), + }); + + const data = await response.json(); + res.json(data); + } catch (error) { + res.status(500).json({ + error: + "Please create an account at bakery.bagel.net and get an API key. Then set the BAGEL_API_KEY environment variable.", + details: error.message, + }); + } + } + ); + + this.app.get( + "/fine-tune/:assetId", + async (req: express.Request, res: express.Response) => { + const assetId = req.params.assetId; + const downloadDir = path.join(process.cwd(), "downloads", assetId); + + logger.log("Download directory:", downloadDir); + + try { + logger.log("Creating directory..."); + await fs.promises.mkdir(downloadDir, { recursive: true }); + + logger.log("Fetching file..."); + const fileResponse = await fetch( + `https://api.bageldb.ai/api/v1/asset/${assetId}/download`, + { + headers: { + "X-API-KEY": `${process.env.BAGEL_API_KEY}`, + }, } + ); + + if (!fileResponse.ok) { + throw new Error( + `API responded with status ${ + fileResponse.status + }: ${await fileResponse.text()}` + ); + } + + logger.log("Response headers:", fileResponse.headers); + + const fileName = + fileResponse.headers + .get("content-disposition") + ?.split("filename=")[1] + ?.replace(/"/g, /* " */ "") || "default_name.txt"; + + logger.log("Saving as:", fileName); + + const arrayBuffer = await fileResponse.arrayBuffer(); + const buffer = Buffer.from(arrayBuffer); + + const filePath = path.join(downloadDir, fileName); + logger.log("Full file path:", filePath); + + await fs.promises.writeFile(filePath, buffer); + + // Verify file was written + const stats = await fs.promises.stat(filePath); + logger.log("File written successfully. Size:", stats.size, "bytes"); + + res.json({ + success: true, + message: "Single file downloaded successfully", + downloadPath: downloadDir, + fileCount: 1, + fileName: fileName, + fileSize: stats.size, + }); + } catch (error) { + logger.error("Detailed error:", error); + res.status(500).json({ + error: "Failed to download files from BagelDB", + details: error.message, + stack: error.stack, + }); } - } + } + ); + + this.app.post("/:agentId/speak", async (req, res) => { + const agentId = req.params.agentId; + const roomId = stringToUuid(req.body.roomId ?? `default-room-${agentId}`); + const userId = stringToUuid(req.body.userId ?? 
"user"); + const text = req.body.text; + + if (!text) { + res.status(400).send("No text provided"); + return; + } + + let runtime = this.agents.get(agentId); + + // if runtime is null, look for runtime with the same name + if (!runtime) { + runtime = Array.from(this.agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } + + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } + + try { + // Process message through agent (same as /message endpoint) + await runtime.ensureConnection({ + userId, + roomId, + userName: req.body.userName, + userScreenName: req.body.name, + source: "direct", + type: ChannelType.API, + }); - public unregisterAgent(runtime: IAgentRuntime) { - this.agents.delete(runtime.agentId); - } + const messageId = stringToUuid(Date.now().toString()); - public start(port: number) { - this.server = this.app.listen(port, () => { - logger.success( - `REST API bound to 0.0.0.0:${port}. If running locally, access it at http://localhost:${port}.` - ); + const content: Content = { + text, + attachments: [], + source: "direct", + inReplyTo: undefined, + }; + + const userMessage = { + content, + userId, + roomId, + agentId: runtime.agentId, + }; + + const memory: Memory = { + id: messageId, + agentId: runtime.agentId, + userId, + roomId, + content, + createdAt: Date.now(), + }; + + await runtime.messageManager.createMemory(memory); + + const state = await runtime.composeState(userMessage, { + agentName: runtime.character.name, }); - // Handle graceful shutdown - const gracefulShutdown = () => { - logger.log("Received shutdown signal, closing server..."); - this.server.close(() => { - logger.success("Server closed successfully"); - process.exit(0); - }); + const context = composeContext({ + state, + template: messageHandlerTemplate, + }); + + const response = await generateMessageResponse({ + runtime: runtime, + context, + modelClass: ModelClass.TEXT_LARGE, + }); - // Force close after 5 seconds if server hasn't closed - setTimeout(() => { - logger.error( - "Could not close connections in time, forcefully shutting down" - ); - process.exit(1); - }, 5000); + // save response to memory + const responseMessage = { + ...userMessage, + userId: runtime.agentId, + content: response, }; - // Handle different shutdown signals - process.on("SIGTERM", gracefulShutdown); - process.on("SIGINT", gracefulShutdown); - } + await runtime.messageManager.createMemory(responseMessage); - public async stop() { - if (this.server) { - this.server.close(() => { - logger.success("Server stopped"); - }); + if (!response) { + res.status(500).send("No response from generateMessageResponse"); + return; } + + await runtime.evaluate(memory, state); + + const _result = await runtime.processActions( + memory, + [responseMessage], + state, + async () => { + return [memory]; + } + ); + + // Get the text to convert to speech + const textToSpeak = response.text; + + const speechResponse = await runtime.useModel( + ModelClass.TEXT_TO_SPEECH, + textToSpeak + ); + + if (!speechResponse.ok) { + throw new Error(`ElevenLabs API error: ${speechResponse.statusText}`); + } + + const audioBuffer = await speechResponse.arrayBuffer(); + + // Set appropriate headers for audio streaming + res.set({ + "Content-Type": "audio/mpeg", + "Transfer-Encoding": "chunked", + }); + + res.send(Buffer.from(audioBuffer)); + } catch (error) { + logger.error("Error processing message or generating speech:", error); + res.status(500).json({ + error: "Error processing message or generating 
speech", + details: error.message, + }); + } + }); + + this.app.post("/:agentId/tts", async (req, res) => { + const text = req.body.text; + + if (!text) { + res.status(400).send("No text provided"); + return; + } + + try { + // Convert to speech using ElevenLabs + const elevenLabsApiUrl = `https://api.elevenlabs.io/v1/text-to-speech/${process.env.ELEVENLABS_VOICE_ID}`; + const apiKey = process.env.ELEVENLABS_XI_API_KEY; + + if (!apiKey) { + throw new Error("ELEVENLABS_XI_API_KEY not configured"); + } + + // TODO: Replace the process.env with settings from the character read from the database + + const speechResponse = await fetch(elevenLabsApiUrl, { + method: "POST", + headers: { + "Content-Type": "application/json", + "xi-api-key": apiKey, + }, + body: JSON.stringify({ + text, + model_id: + process.env.ELEVENLABS_MODEL_ID || "eleven_multilingual_v2", + voice_settings: { + stability: Number.parseFloat( + process.env.ELEVENLABS_VOICE_STABILITY || "0.5" + ), + similarity_boost: Number.parseFloat( + process.env.ELEVENLABS_VOICE_SIMILARITY_BOOST || "0.9" + ), + style: Number.parseFloat( + process.env.ELEVENLABS_VOICE_STYLE || "0.66" + ), + use_speaker_boost: + process.env.ELEVENLABS_VOICE_USE_SPEAKER_BOOST === "true", + }, + }), + }); + + if (!speechResponse.ok) { + throw new Error(`ElevenLabs API error: ${speechResponse.statusText}`); + } + + const audioBuffer = await speechResponse.arrayBuffer(); + + res.set({ + "Content-Type": "audio/mpeg", + "Transfer-Encoding": "chunked", + }); + + res.send(Buffer.from(audioBuffer)); + } catch (error) { + logger.error("Error processing message or generating speech:", error); + res.status(500).json({ + error: "Error processing message or generating speech", + details: error.message, + }); + } + }); + } + + // agent/src/index.ts:startAgent calls this + public registerAgent(runtime: IAgentRuntime) { + // register any plugin endpoints? + // but once and only once + this.agents.set(runtime.agentId, runtime); + // TODO: This is a hack to register the tee plugin. Remove this once we have a better way to do it. + const teePlugin = runtime.plugins.find( + (p) => p.name === "phala-tee-plugin" + ); + if (teePlugin) { + for (const provider of teePlugin.providers) { + runtime.registerProvider(provider); + } + for (const action of teePlugin.actions) { + runtime.registerAction(action); + } + } + runtime.registerAction(replyAction); + // for each route on each plugin, add it to the router + for (const route of runtime.routes) { + // if the path hasn't been added yet, add it + switch (route.type) { + case "GET": + this.app.get(route.path, (req: any, res: any) => + route.handler(req, res) + ); + break; + case "POST": + this.app.post(route.path, (req: any, res: any) => + route.handler(req, res) + ); + break; + case "PUT": + this.app.put(route.path, (req: any, res: any) => + route.handler(req, res) + ); + break; + case "DELETE": + this.app.delete(route.path, (req: any, res: any) => + route.handler(req, res) + ); + break; + default: + logger.error(`Unknown route type: ${route.type}`); + } + } + } + + public unregisterAgent(runtime: IAgentRuntime) { + this.agents.delete(runtime.agentId); + } + + public registerMiddleware(middleware: ServerMiddleware) { + this.app.use(middleware); + } + + public start(port: number) { + this.server = this.app.listen(port, () => { + logger.success( + `REST API bound to 0.0.0.0:${port}. 
If running locally, access it at http://localhost:${port}.` + ); + }); + + // Handle graceful shutdown + const gracefulShutdown = () => { + logger.log("Received shutdown signal, closing server..."); + this.server.close(() => { + logger.success("Server closed successfully"); + process.exit(0); + }); + + // Force close after 5 seconds if server hasn't closed + setTimeout(() => { + logger.error( + "Could not close connections in time, forcefully shutting down" + ); + process.exit(1); + }, 5000); + }; + + // Handle different shutdown signals + process.on("SIGTERM", gracefulShutdown); + process.on("SIGINT", gracefulShutdown); + } + + public async stop() { + if (this.server) { + this.server.close(() => { + logger.success("Server stopped"); + }); } -} \ No newline at end of file + } +} From 0a5d6cb2bfb962c96c143be2609faef8cc6895bf Mon Sep 17 00:00:00 2001 From: v1xingyue Date: Mon, 24 Feb 2025 10:13:07 +0800 Subject: [PATCH 02/13] allow null options with agent server --- packages/agent/src/server/index.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/agent/src/server/index.ts b/packages/agent/src/server/index.ts index 7d9667e00a4..c41fa368f0e 100644 --- a/packages/agent/src/server/index.ts +++ b/packages/agent/src/server/index.ts @@ -50,7 +50,7 @@ export class CharacterServer { character: string | never ) => Promise; // Store jsonToCharacter function - constructor(options: ServerOptions) { + constructor(options?: ServerOptions) { logger.log("DirectClient constructor"); this.app = express(); this.app.use(cors()); @@ -59,7 +59,7 @@ export class CharacterServer { this.app.use(bodyParser.json()); this.app.use(bodyParser.urlencoded({ extended: true })); - if (options.middlewares) { + if (options?.middlewares) { for (const middleware of options.middlewares) { this.app.use(middleware); } From de077ed198e760544a07a9054efb5eb413a90ee9 Mon Sep 17 00:00:00 2001 From: AIFlow_ML Date: Mon, 24 Feb 2025 12:25:08 +0700 Subject: [PATCH 03/13] Added the TTS manager, implemented in the init a basic test that state if model produce data or not. Fixed also a precedent issue in the vision. --- packages/plugin-local-ai/src/index.ts | 66 ++++++- packages/plugin-local-ai/src/types.ts | 2 +- .../src/utils/downloadManager.ts | 18 +- .../plugin-local-ai/src/utils/ttsManager.ts | 168 +++++++++++++++--- .../src/utils/visionManager.ts | 2 +- 5 files changed, 219 insertions(+), 37 deletions(-) diff --git a/packages/plugin-local-ai/src/index.ts b/packages/plugin-local-ai/src/index.ts index 321c1831536..6c2b7fa0066 100644 --- a/packages/plugin-local-ai/src/index.ts +++ b/packages/plugin-local-ai/src/index.ts @@ -124,6 +124,14 @@ class LocalAIManager { stack: error instanceof Error ? error.stack : undefined }); return null; // Prevent Promise.all from failing completely + }), + // Add TTS initialization + this.initializeTTS().catch(error => { + logger.warn("TTS initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? 
error.stack : undefined + }); + return null; // Prevent Promise.all from failing completely }) ]).catch(error => { logger.warn("Models initialization failed:", { @@ -241,10 +249,10 @@ class LocalAIManager { } const imageBuffer = fs.readFileSync(imagePath); - logger.info("Test image loaded:", { - size: imageBuffer.length, - path: imagePath - }); + // logger.info("Test image loaded:", { + // size: imageBuffer.length, + // path: imagePath + // }); // Process the test image const result = await this.describeImage(imageBuffer, 'image/jpeg'); @@ -257,6 +265,56 @@ class LocalAIManager { } } + private async initializeTTS(): Promise { + try { + logger.info("Initializing TTS model..."); + + // Test text for TTS + const testText = "ElizaOS is yours"; + + // Generate speech from test text + logger.info("Testing TTS with sample text:", { text: testText }); + const audioStream = await this.ttsManager.generateSpeech(testText); + + // Verify the stream is readable + if (!(audioStream instanceof Readable)) { + throw new Error("TTS did not return a valid audio stream"); + } + + // Test stream readability + let dataReceived = false; + await new Promise((resolve, reject) => { + audioStream.on('data', () => { + if (!dataReceived) { + dataReceived = true; + logger.info("TTS audio stream is producing data"); + } + }); + + audioStream.on('end', () => { + if (!dataReceived) { + reject(new Error("No audio data received from TTS stream")); + } else { + resolve(); + } + }); + + audioStream.on('error', (err) => { + reject(err); + }); + }); + + logger.success("TTS model initialization complete"); + } catch (error) { + logger.error("TTS initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + timestamp: new Date().toISOString() + }); + throw error; + } + } + private async downloadModel(): Promise { try { // Determine which model to download based on current modelPath diff --git a/packages/plugin-local-ai/src/types.ts b/packages/plugin-local-ai/src/types.ts index 75607240026..3831c06cf1e 100644 --- a/packages/plugin-local-ai/src/types.ts +++ b/packages/plugin-local-ai/src/types.ts @@ -112,7 +112,7 @@ export const MODEL_SPECS: ModelSpecs = { }, tts: { base: { - name: "OuteTTS-0.2-500M.gguf", + name: "OuteTTS-0.2-500M-Q8_0.gguf", repo: "OuteAI/OuteTTS-0.2-500M-GGUF", size: "500M", quantization: "Q8_0", diff --git a/packages/plugin-local-ai/src/utils/downloadManager.ts b/packages/plugin-local-ai/src/utils/downloadManager.ts index b57302913d9..4bfdd9053ab 100644 --- a/packages/plugin-local-ai/src/utils/downloadManager.ts +++ b/packages/plugin-local-ai/src/utils/downloadManager.ts @@ -124,8 +124,22 @@ export class DownloadManager { } if (!fs.existsSync(modelPath)) { - const downloadUrl = `https://huggingface.co/${modelSpec.repo}/resolve/main/${modelSpec.name}`; - logger.info("Download URL:", downloadUrl); + // For GGUF models, we need to adjust the repo path + const repoPath = modelSpec.repo.replace('-GGUF', ''); + const downloadUrl = `https://huggingface.co/${repoPath}/resolve/main/${modelSpec.name}`; + logger.info("Model download details:", { + originalRepo: modelSpec.repo, + adjustedRepo: repoPath, + modelName: modelSpec.name, + downloadUrl, + alternativeUrls: { + withGGUF: `https://huggingface.co/${modelSpec.repo}/resolve/main/${modelSpec.name}`, + rawUrl: `https://huggingface.co/${modelSpec.repo}/blob/main/${modelSpec.name}`, + lfsUrl: `https://huggingface.co/${modelSpec.repo}/resolve/main/${modelSpec.name}?download=true` + }, 
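// Of the URL variants listed here, only downloadUrl is actually fetched
// below; the alternatives exist only for this diagnostic log entry.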
+ modelPath: modelPath, + timestamp: new Date().toISOString() + }); await this.downloadFile(downloadUrl, modelPath); logger.success(`Model download complete: ${modelSpec.name}`); diff --git a/packages/plugin-local-ai/src/utils/ttsManager.ts b/packages/plugin-local-ai/src/utils/ttsManager.ts index 1fa1c605923..db698ac7c5c 100644 --- a/packages/plugin-local-ai/src/utils/ttsManager.ts +++ b/packages/plugin-local-ai/src/utils/ttsManager.ts @@ -29,6 +29,11 @@ export class TTSManager { : path.join(process.cwd(), "models"); this.downloadManager = DownloadManager.getInstance(this.cacheDir); this.ensureCacheDirectory(); + logger.info("TTSManager initialized with configuration:", { + cacheDir: this.cacheDir, + modelsDir: this.modelsDir, + timestamp: new Date().toISOString() + }); } public static getInstance(cacheDir: string): TTSManager { @@ -53,27 +58,96 @@ export class TTSManager { logger.info("Initializing TTS with GGUF backend..."); - // Download the model if needed - const modelPath = path.join(this.modelsDir, MODEL_SPECS.tts.base.name); - await this.downloadManager.downloadModel(MODEL_SPECS.tts.base, modelPath); + const modelSpec = MODEL_SPECS.tts.base; + const modelPath = path.join(this.modelsDir, modelSpec.name); + // Log detailed model configuration and paths + logger.info("TTS model configuration:", { + name: modelSpec.name, + repo: modelSpec.repo, + modelPath, + timestamp: new Date().toISOString() + }); + + if (!fs.existsSync(modelPath)) { + // Try different URL patterns in sequence + const attempts = [ + { + spec: { ...modelSpec }, + description: "Standard URL with GGUF", + url: `https://huggingface.co/${modelSpec.repo}/resolve/main/${modelSpec.name}?download=true` + }, + { + spec: { ...modelSpec, repo: modelSpec.repo.replace('-GGUF', '') }, + description: "URL without GGUF suffix", + url: `https://huggingface.co/${modelSpec.repo.replace('-GGUF', '')}/resolve/main/${modelSpec.name}?download=true` + }, + { + spec: { ...modelSpec, name: modelSpec.name.replace('-Q8_0', '') }, + description: "URL without quantization suffix", + url: `https://huggingface.co/${modelSpec.repo}/resolve/main/${modelSpec.name.replace('-Q8_0', '')}.gguf?download=true` + } + ]; + + let lastError = null; + for (const attempt of attempts) { + try { + logger.info("Attempting TTS model download:", { + description: attempt.description, + repo: attempt.spec.repo, + name: attempt.spec.name, + url: attempt.url, + timestamp: new Date().toISOString() + }); + + const barLength = 20; + const progressBar = 'â–ˆ'.repeat(barLength); + logger.info(`TTS model download: ${progressBar} Starting...`); + + await this.downloadManager.downloadFromUrl(attempt.url, modelPath); + + logger.info(`TTS model download: ${progressBar} 100%`); + logger.success("TTS model download successful with:", attempt.description); + break; + } catch (error) { + lastError = error; + logger.warn("TTS model download attempt failed:", { + description: attempt.description, + error: error instanceof Error ? 
error.message : String(error), + timestamp: new Date().toISOString() + }); + } + } + + if (!fs.existsSync(modelPath)) { + throw lastError || new Error("All download attempts failed"); + } + } + + logger.info("Loading TTS model..."); const llama = await getLlama(); this.model = await llama.loadModel({ - modelPath + modelPath, + gpuLayers: 0 // Force CPU for now until we add GPU support }); this.ctx = await this.model.createContext({ - contextSize: MODEL_SPECS.tts.base.contextSize + contextSize: modelSpec.contextSize }); this.sequence = this.ctx.getSequence(); - logger.success("TTS initialization complete"); + logger.success("TTS initialization complete", { + modelPath, + contextSize: modelSpec.contextSize, + timestamp: new Date().toISOString() + }); this.initialized = true; } catch (error) { logger.error("TTS initialization failed:", { error: error instanceof Error ? error.message : String(error), - model: MODEL_SPECS.tts.base.name + model: MODEL_SPECS.tts.base.name, + timestamp: new Date().toISOString() }); throw error; } @@ -87,48 +161,84 @@ export class TTSManager { throw new Error("TTS model not initialized"); } - logger.info("Generating speech for text:", { length: text.length }); + logger.info("Starting speech generation for text:", { text }); // Format prompt for TTS generation - const prompt = `[SPEAKER=male_1][LANGUAGE=en]${text}`; + const prompt = `[SPEAKER=female_1][LANGUAGE=en]${text}`; + logger.info("Formatted prompt:", { prompt }); - // Tokenize the input text + // Tokenize input + logger.info("Tokenizing input..."); const inputTokens = this.model.tokenize(prompt); - - // Generate audio tokens + logger.info("Input tokenized:", { tokenCount: inputTokens.length }); + + // Generate audio tokens with optimized limit (2x input) + const maxTokens = inputTokens.length * 2; + logger.info("Starting token generation with optimized limit:", { maxTokens }); const responseTokens: Token[] = []; - for await (const token of this.sequence.evaluate(inputTokens, { - temperature: 0.1, - repeatPenalty: { - punishTokens: () => [], - penalty: 1.0, - frequencyPenalty: 0.0, - presencePenalty: 0.0 + const startTime = Date.now(); + + try { + for await (const token of this.sequence.evaluate(inputTokens, { + temperature: 0.1, + + })) { + responseTokens.push(token); + + // Update progress bar + const progress = Math.round((responseTokens.length / maxTokens) * 100); + const barLength = 20; + const filledLength = Math.floor((progress / 100) * barLength); + const bar = 'â–ˆ'.repeat(filledLength) + 'â–‘'.repeat(barLength - filledLength); + logger.info(`Token generation: ${bar} ${progress}% (${responseTokens.length}/${maxTokens})`); + + // Stop if we hit our token limit + if (responseTokens.length >= maxTokens) { + logger.info("Token generation complete"); + break; + } } - })) { - responseTokens.push(token); + } catch (error) { + logger.error("Token generation error:", error); + throw error; } - // Convert tokens to audio data - if (!this.model) { - throw new Error("Model not initialized"); + logger.info("Token generation stats:", { + inputTokens: inputTokens.length, + outputTokens: responseTokens.length, + timeMs: Date.now() - startTime + }); + + if (responseTokens.length === 0) { + throw new Error("No audio tokens generated"); } + + // Convert tokens to audio data + logger.info("Converting tokens to audio data..."); const audioData = this.processAudioResponse({ tokens: responseTokens.map(t => Number.parseInt(this.model.detokenize([t]), 10)) }); - // Create WAV format with proper headers - return 
prependWavHeader( + logger.info("Audio data generated:", { + byteLength: audioData.length, + sampleRate: MODEL_SPECS.tts.base.sampleRate + }); + + // Create WAV format + const audioStream = prependWavHeader( Readable.from(audioData), audioData.length, MODEL_SPECS.tts.base.sampleRate, - 1, // mono channel - 16 // 16-bit PCM + 1, + 16 ); + + logger.success("Speech generation complete"); + return audioStream; } catch (error) { logger.error("Speech generation failed:", { error: error instanceof Error ? error.message : String(error), - textLength: text.length + text }); throw error; } diff --git a/packages/plugin-local-ai/src/utils/visionManager.ts b/packages/plugin-local-ai/src/utils/visionManager.ts index f639692929e..3de4140a71f 100644 --- a/packages/plugin-local-ai/src/utils/visionManager.ts +++ b/packages/plugin-local-ai/src/utils/visionManager.ts @@ -318,7 +318,7 @@ export class VisionManager { private async fetchImage(url: string): Promise<{ buffer: Buffer; mimeType: string }> { try { - logger.info("Fetching image from URL:", { url }); + // logger.info("Fetching image from URL:", { url }); // Handle data URLs differently if (url.startsWith('data:')) { From c205a7c540b573f22ead67e14b6cb1f5dc2ab16b Mon Sep 17 00:00:00 2001 From: v1xingyue Date: Mon, 24 Feb 2025 14:07:08 +0800 Subject: [PATCH 04/13] Update index.ts --- packages/agent/src/server/index.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/agent/src/server/index.ts b/packages/agent/src/server/index.ts index 996d2ebbf64..0b4bd72a69d 100644 --- a/packages/agent/src/server/index.ts +++ b/packages/agent/src/server/index.ts @@ -439,7 +439,7 @@ export class AgentServer { } } if (hfOut.emote !== null) { - contentObj.text = `emoted ${hfOut.emote}`; + contentObj.text += `emoted ${hfOut.emote}`; } } From d3d8c0495e4540d236b8a87116f2bc981df40138 Mon Sep 17 00:00:00 2001 From: AIFlow_ML Date: Mon, 24 Feb 2025 16:27:16 +0700 Subject: [PATCH 05/13] Added Ollama and StudioLM managers. 
Added initial envirnoment check - Ollama pefroermance very bad compared to StudioLM --- .env.example | 25 +- packages/plugin-local-ai/src/environment.ts | 80 +++++ packages/plugin-local-ai/src/index.ts | 240 ++++++++++++-- .../src/utils/ollamaManager.ts | 292 ++++++++++++++++ .../src/utils/studiolmManager.ts | 311 ++++++++++++++++++ 5 files changed, 925 insertions(+), 23 deletions(-) create mode 100644 packages/plugin-local-ai/src/environment.ts create mode 100644 packages/plugin-local-ai/src/utils/ollamaManager.ts create mode 100644 packages/plugin-local-ai/src/utils/studiolmManager.ts diff --git a/.env.example b/.env.example index ffde56e3d24..114adbb11e6 100644 --- a/.env.example +++ b/.env.example @@ -28,4 +28,27 @@ BIRDEYE_API_KEY= # Swarm settings COMPLIANCE_OFFICER_DISCORD_APPLICATION_ID= -COMPLIANCE_OFFICER_DISCORD_API_TOKEN= \ No newline at end of file +COMPLIANCE_OFFICER_DISCORD_API_TOKEN= + + +# ================================== +# Local AI Configuration +# ================================== +USE_LOCAL_AI=true +USE_STUDIOLM_TEXT_MODELS=false +USE_OLLAMA_TEXT_MODELS=false + +# Ollama Configuration +OLLAMA_SERVER_URL=http://localhost:11434 +OLLAMA_MODEL=deepseek-r1-distill-qwen-7b +USE_OLLAMA_EMBEDDING=false +OLLAMA_EMBEDDING_MODEL= +SMALL_OLLAMA_MODEL=deepseek-r1:1.5b +MEDIUM_OLLAMA_MODEL=deepseek-r1:7b +LARGE_OLLAMA_MODEL=deepseek-r1:7b + +# StudioLM Configuration +STUDIOLM_SERVER_URL=http://localhost:1234 +STUDIOLM_SMALL_MODEL=lmstudio-community/deepseek-r1-distill-qwen-1.5b +STUDIOLM_MEDIUM_MODEL=deepseek-r1-distill-qwen-7b +STUDIOLM_EMBEDDING_MODEL=false \ No newline at end of file diff --git a/packages/plugin-local-ai/src/environment.ts b/packages/plugin-local-ai/src/environment.ts new file mode 100644 index 00000000000..eaec9981059 --- /dev/null +++ b/packages/plugin-local-ai/src/environment.ts @@ -0,0 +1,80 @@ +import type { IAgentRuntime } from "@elizaos/core"; +import { z } from "zod"; +import { logger } from "@elizaos/core"; + +// Configuration schema with text model source flags +export const configSchema = z.object({ + USE_LOCAL_AI: z.boolean().default(true), + USE_STUDIOLM_TEXT_MODELS: z.boolean().default(false), + USE_OLLAMA_TEXT_MODELS: z.boolean().default(false), +}); + +export type Config = z.infer; + +function validateModelConfig(config: Record): void { + // Log raw values before validation + logger.info("Validating model configuration with values:", { + USE_LOCAL_AI: config.USE_LOCAL_AI, + USE_STUDIOLM_TEXT_MODELS: config.USE_STUDIOLM_TEXT_MODELS, + USE_OLLAMA_TEXT_MODELS: config.USE_OLLAMA_TEXT_MODELS + }); + + // Ensure USE_LOCAL_AI is always true + if (!config.USE_LOCAL_AI) { + config.USE_LOCAL_AI = true; + logger.info("Setting USE_LOCAL_AI to true as it's required"); + } + + // Only validate that StudioLM and Ollama are not both enabled + if (config.USE_STUDIOLM_TEXT_MODELS && config.USE_OLLAMA_TEXT_MODELS) { + throw new Error("StudioLM and Ollama text models cannot be enabled simultaneously"); + } + + logger.info("Configuration is valid"); +} + +export async function validateConfig( + config: Record +): Promise { + try { + // Log raw environment variables + logger.info("Raw environment variables:", { + USE_LOCAL_AI: process.env.USE_LOCAL_AI, + USE_STUDIOLM_TEXT_MODELS: process.env.USE_STUDIOLM_TEXT_MODELS, + USE_OLLAMA_TEXT_MODELS: process.env.USE_OLLAMA_TEXT_MODELS + }); + + // Parse environment variables with proper boolean conversion + const booleanConfig = { + USE_LOCAL_AI: true, // Always true + USE_STUDIOLM_TEXT_MODELS: 
config.USE_STUDIOLM_TEXT_MODELS === 'true', + USE_OLLAMA_TEXT_MODELS: config.USE_OLLAMA_TEXT_MODELS === 'true', + }; + + logger.info("Parsed boolean configuration:", booleanConfig); + + // Validate text model source configuration + validateModelConfig(booleanConfig); + + const validatedConfig = configSchema.parse(booleanConfig); + + logger.info("Final validated configuration:", validatedConfig); + + return validatedConfig; + } catch (error) { + if (error instanceof z.ZodError) { + const errorMessages = error.errors + .map((err) => `${err.path.join(".")}: ${err.message}`) + .join("\n"); + logger.error("Zod validation failed:", errorMessages); + throw new Error( + `Configuration validation failed:\n${errorMessages}` + ); + } + logger.error("Configuration validation failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); + throw error; + } +} \ No newline at end of file diff --git a/packages/plugin-local-ai/src/index.ts b/packages/plugin-local-ai/src/index.ts index 6c2b7fa0066..7b79075b8eb 100644 --- a/packages/plugin-local-ai/src/index.ts +++ b/packages/plugin-local-ai/src/index.ts @@ -27,16 +27,19 @@ import { DownloadManager } from './utils/downloadManager'; import { VisionManager } from './utils/visionManager'; import { TranscribeManager } from './utils/transcribeManager'; import { TTSManager } from './utils/ttsManager'; +import { StudioLMManager } from './utils/studiolmManager'; +import { OllamaManager } from './utils/ollamaManager'; +import { validateConfig } from "./environment"; // const execAsync = promisify(exec); const __filename = fileURLToPath(import.meta.url); const __dirname = path.dirname(__filename); // Configuration schema -const configSchema = z.object({ - LLAMALOCAL_PATH: z.string().optional(), - CACHE_DIR: z.string().optional().default("./cache"), -}); +// const configSchema = z.object({ +// LLAMALOCAL_PATH: z.string().optional(), +// CACHE_DIR: z.string().optional().default("./cache"), +// }); // Words to punish in LLM responses const wordsToPunish = [ @@ -49,6 +52,13 @@ const wordsToPunish = [ " algorithm", " Indeed", " Furthermore", " However", " Notably", " Therefore" ]; +// Add type definitions for model source selection +type TextModelSource = 'local' | 'studiolm' | 'ollama'; + +interface TextModelConfig { + source: TextModelSource; + modelClass: ModelClass; +} class LocalAIManager { private static instance: LocalAIManager | null = null; @@ -68,6 +78,10 @@ class LocalAIManager { private activeModelConfig: ModelSpec; private transcribeManager: TranscribeManager; private ttsManager: TTSManager; + private studioLMManager: StudioLMManager; + private ollamaManager: OllamaManager; + private ollamaInitialized = false; + private studioLMInitialized = false; private constructor() { // Ensure we have a valid models directory @@ -97,6 +111,14 @@ class LocalAIManager { this.visionManager = VisionManager.getInstance(this.cacheDir); this.transcribeManager = TranscribeManager.getInstance(this.cacheDir); this.ttsManager = TTSManager.getInstance(this.cacheDir); + this.studioLMManager = StudioLMManager.getInstance(); + this.ollamaManager = OllamaManager.getInstance(); + + // Initialize environment + this.initializeEnvironment().catch(error => { + logger.error("Environment initialization failed:", error); + throw error; + }); // Add platform capabilities check in constructor this.checkPlatformCapabilities().catch(error => { @@ -132,7 +154,27 @@ class LocalAIManager { stack: error instanceof Error ? 
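One detail worth calling out in `validateConfig` above: environment variables are always strings, and `Boolean("false")` is `true`, so the explicit `=== 'true'` comparison is the only safe conversion. A self-contained sketch of the same idea, expressing the mutual-exclusivity rule from `validateModelConfig` as a zod `refine` (an alternative formulation, not the patch's exact code):

```typescript
import { z } from "zod";

// Env vars are strings; compare explicitly rather than relying on truthiness.
const toBool = (v: string | undefined): boolean => v === "true";

const flagsSchema = z
  .object({
    USE_LOCAL_AI: z.boolean().default(true),
    USE_STUDIOLM_TEXT_MODELS: z.boolean().default(false),
    USE_OLLAMA_TEXT_MODELS: z.boolean().default(false),
  })
  .refine((c) => !(c.USE_STUDIOLM_TEXT_MODELS && c.USE_OLLAMA_TEXT_MODELS), {
    message: "StudioLM and Ollama text models cannot be enabled simultaneously",
  });

const flags = flagsSchema.parse({
  USE_LOCAL_AI: true, // forced on, as validateModelConfig does
  USE_STUDIOLM_TEXT_MODELS: toBool(process.env.USE_STUDIOLM_TEXT_MODELS),
  USE_OLLAMA_TEXT_MODELS: toBool(process.env.USE_OLLAMA_TEXT_MODELS),
});
```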
error.stack : undefined }); return null; // Prevent Promise.all from failing completely - }) + }), + // Add StudioLM initialization + this.initializeStudioLM().then(() => { + this.studioLMInitialized = true; + }).catch(error => { + logger.warn("StudioLM initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); + return null; // Prevent Promise.all from failing completely + }), + // Add Ollama initialization + this.initializeOllama().then(() => { + this.ollamaInitialized = true; + }).catch(error => { + logger.warn("Ollama initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); + return null; // Prevent Promise.all from failing completely + }), ]).catch(error => { logger.warn("Models initialization failed:", { error: error instanceof Error ? error.message : String(error), @@ -148,6 +190,80 @@ class LocalAIManager { return LocalAIManager.instance; } + private async initializeEnvironment(): Promise { + try { + logger.info("Validating environment configuration..."); + + // Create initial config from current env vars + const config = { + USE_LOCAL_AI: process.env.USE_LOCAL_AI, + USE_STUDIOLM_TEXT_MODELS: process.env.USE_STUDIOLM_TEXT_MODELS, + USE_OLLAMA_TEXT_MODELS: process.env.USE_OLLAMA_TEXT_MODELS + }; + + // Validate configuration + const validatedConfig = await validateConfig(config); + + // Log the validated configuration + logger.info("Environment configuration validated:", validatedConfig); + + + logger.success("Environment initialization complete"); + } catch (error) { + logger.error("Environment validation failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); + throw error; + } + } + + private async initializeOllama(): Promise { + try { + logger.info("Initializing Ollama models..."); + this.ollamaManager = OllamaManager.getInstance(); + + // Initialize and test models + await this.ollamaManager.initialize(); + + if (!this.ollamaManager.isInitialized()) { + throw new Error("Ollama initialization failed - models not properly loaded"); + } + + logger.success("Ollama initialization complete"); + } catch (error) { + logger.error("Ollama initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + timestamp: new Date().toISOString() + }); + throw error; + } + } + + private async initializeStudioLM(): Promise { + try { + logger.info("Initializing StudioLM models..."); + + // Initialize and test models + await this.studioLMManager.initialize(); + + if (!this.studioLMManager.isInitialized()) { + throw new Error("StudioLM initialization failed - models not properly loaded"); + } + + this.studioLMInitialized = true; + logger.success("StudioLM initialization complete"); + } catch (error) { + logger.error("StudioLM initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? 
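The initialization code above chains `.catch` onto every entry of the `Promise.all` array and returns `null` from the handler; without that, a single rejected backend (say, Ollama not running) would reject the whole aggregate and abort startup. A minimal illustration of the pattern (illustrative names, not the patch's code):

```typescript
// Each optional initializer is shielded so one failure cannot sink the others.
async function initOptionalBackends(
  initializers: Record<string, () => Promise<void>>
): Promise<void> {
  await Promise.all(
    Object.entries(initializers).map(([name, init]) =>
      init().catch((error) => {
        console.warn(`${name} initialization failed, continuing:`, error);
        return null; // resolve instead of reject
      })
    )
  );
}
```

`Promise.allSettled` would express the same intent without the per-entry handlers, at the cost of inspecting the settled results afterwards.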
error.stack : undefined, + timestamp: new Date().toISOString() + }); + throw error; + } + } + private async initializeTranscription(): Promise { try { logger.info("Initializing transcription model..."); @@ -249,10 +365,6 @@ class LocalAIManager { } const imageBuffer = fs.readFileSync(imagePath); - // logger.info("Test image loaded:", { - // size: imageBuffer.length, - // path: imagePath - // }); // Process the test image const result = await this.describeImage(imageBuffer, 'image/jpeg'); @@ -419,7 +531,47 @@ class LocalAIManager { } } - + async generateTextOllamaStudio(params: GenerateTextParams): Promise { + try { + const modelConfig = this.getTextModelSource(); + logger.info("generateTextOllamaStudio called with:", { + modelSource: modelConfig.source, + modelClass: params.modelClass, + studioLMInitialized: this.studioLMInitialized, + studioLMManagerInitialized: this.studioLMManager.isInitialized() + }); + + if (modelConfig.source === 'studiolm') { + // Only initialize if not already initialized + if (!this.studioLMInitialized) { + logger.info("StudioLM not initialized, initializing now..."); + await this.initializeStudioLM(); + } + + // Pass initialization flag to generateText + return await this.studioLMManager.generateText(params, this.studioLMInitialized); + } + + if (modelConfig.source === 'ollama') { + // Only initialize if not already initialized + if (!this.ollamaInitialized && !this.ollamaManager.isInitialized()) { + logger.info("Initializing Ollama in generateTextOllamaStudio"); + await this.ollamaManager.initialize(); + this.ollamaInitialized = true; + } + + // Pass initialization flag to generateText + return await this.ollamaManager.generateText(params, this.ollamaInitialized); + } + + // Fallback to local models if something goes wrong + return this.generateText(params); + } catch (error) { + logger.error("Text generation with Ollama/StudioLM failed:", error); + // Fallback to local models + return this.generateText(params); + } + } async generateText(params: GenerateTextParams): Promise { try { @@ -597,6 +749,30 @@ class LocalAIManager { public getActiveModelConfig(): ModelSpec { return this.activeModelConfig; } + + public getTextModelSource(): TextModelConfig { + try { + // Default configuration + const config: TextModelConfig = { + source: 'local', + modelClass: ModelClass.TEXT_SMALL + }; + + // Check environment configuration + if (process.env.USE_STUDIOLM_TEXT_MODELS === 'true') { + config.source = 'studiolm'; + } else if (process.env.USE_OLLAMA_TEXT_MODELS === 'true') { + config.source = 'ollama'; + } + + logger.info("Selected text model source:", config); + return config; + } catch (error) { + logger.error("Error determining text model source:", error); + // Fallback to local models + return { source: 'local', modelClass: ModelClass.TEXT_SMALL }; + } + } } // Create manager instance @@ -609,22 +785,20 @@ export const localAIPlugin: Plugin = { async init(config: Record) { try { logger.info("Initializing local-ai plugin..."); - const validatedConfig = await configSchema.parseAsync(config); + const validatedConfig = await validateConfig(config); // Set environment variables for (const [key, value] of Object.entries(validatedConfig)) { - if (value) { - process.env[key] = value; - logger.debug(`Set ${key}=${value}`); - } + process.env[key] = String(value); // Convert boolean to string + logger.debug(`Set ${key}=${value}`); } + + logger.success("Local AI plugin configuration validated and initialized"); } catch (error) { - if (error instanceof z.ZodError) { - throw new 
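`generateTextOllamaStudio` above implements a two-level fallback: route by the configured source, lazily initialize the chosen backend, and drop back to the local llama model on any error. The control flow reduces to something like this (a sketch with simplified types, not the patch's code):

```typescript
type TextModelSource = "local" | "studiolm" | "ollama";

// Route to a remote backend when configured, falling back to local on failure.
async function generateWithFallback(
  source: TextModelSource,
  backends: Record<TextModelSource, (prompt: string) => Promise<string>>,
  prompt: string
): Promise<string> {
  try {
    return await backends[source](prompt);
  } catch (error) {
    if (source === "local") throw error; // nothing left to fall back to
    console.warn(`${source} generation failed, falling back to local:`, error);
    return backends.local(prompt);
  }
}
```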
Error( - `Invalid plugin configuration: ${error.errors.map((e) => e.message).join(", ")}` - ); - } - logger.error("Plugin initialization failed:", error); + logger.error("Plugin initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); throw error; } }, @@ -632,6 +806,17 @@ export const localAIPlugin: Plugin = { models: { [ModelClass.TEXT_SMALL]: async (runtime: IAgentRuntime, { context, stopSequences = [] }: GenerateTextParams) => { try { + const modelConfig = localAIManager.getTextModelSource(); + + if (modelConfig.source !== 'local') { + return await localAIManager.generateTextOllamaStudio({ + context, + stopSequences, + runtime, + modelClass: ModelClass.TEXT_SMALL + }); + } + return await localAIManager.generateText({ context, stopSequences, @@ -646,6 +831,17 @@ export const localAIPlugin: Plugin = { [ModelClass.TEXT_LARGE]: async (runtime: IAgentRuntime, { context, stopSequences = [] }: GenerateTextParams) => { try { + const modelConfig = localAIManager.getTextModelSource(); + + if (modelConfig.source !== 'local') { + return await localAIManager.generateTextOllamaStudio({ + context, + stopSequences, + runtime, + modelClass: ModelClass.TEXT_LARGE + }); + } + return await localAIManager.generateText({ context, stopSequences, @@ -1049,7 +1245,7 @@ export const localAIPlugin: Plugin = { { path: "/health", type: "GET", - handler: async (req: any, res: any) => { + handler: async (_req: unknown, res: { json: (data: unknown) => void }) => { res.json({ status: "healthy", models: { diff --git a/packages/plugin-local-ai/src/utils/ollamaManager.ts b/packages/plugin-local-ai/src/utils/ollamaManager.ts new file mode 100644 index 00000000000..8312b6900eb --- /dev/null +++ b/packages/plugin-local-ai/src/utils/ollamaManager.ts @@ -0,0 +1,292 @@ +import { logger, ModelClass, type GenerateTextParams } from "@elizaos/core"; +import fetch from "node-fetch"; + +interface OllamaModel { + name: string; + id: string; + size: string; + modified: string; +} + +interface OllamaResponse { + model: string; + response: string; + done: boolean; + context?: number[]; + total_duration?: number; + load_duration?: number; + prompt_eval_duration?: number; + eval_duration?: number; +} + +export class OllamaManager { + private static instance: OllamaManager | null = null; + private serverUrl: string; + private initialized = false; + private availableModels: OllamaModel[] = []; + private configuredModels = { + small: "deepseek-r1:1.5b", + medium: "deepseek-r1:7b" + }; + + private constructor() { + this.serverUrl = process.env.OLLAMA_SERVER_URL || "http://localhost:11434"; + logger.info("OllamaManager initialized with configuration:", { + serverUrl: this.serverUrl, + configuredModels: this.configuredModels, + timestamp: new Date().toISOString() + }); + } + + public static getInstance(): OllamaManager { + if (!OllamaManager.instance) { + OllamaManager.instance = new OllamaManager(); + } + return OllamaManager.instance; + } + + private async checkServerStatus(): Promise { + try { + const response = await fetch(`${this.serverUrl}/api/tags`); + if (!response.ok) { + throw new Error(`Server responded with status: ${response.status}`); + } + return true; + } catch (error) { + logger.error("Ollama server check failed:", { + error: error instanceof Error ? 
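`checkServerStatus` above leans on the fact that Ollama serves its model list at `GET /api/tags`, which makes a cheap liveness probe. Standalone (a sketch using Node 18+'s global `fetch` rather than the `node-fetch` import the patch uses):

```typescript
// Returns true when an Ollama server answers its model-list endpoint.
async function ollamaIsUp(serverUrl = "http://localhost:11434"): Promise<boolean> {
  try {
    const res = await fetch(`${serverUrl}/api/tags`);
    return res.ok;
  } catch {
    return false;
  }
}
```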
error.message : String(error), + serverUrl: this.serverUrl, + timestamp: new Date().toISOString() + }); + return false; + } + } + + private async fetchAvailableModels(): Promise { + try { + const response = await fetch(`${this.serverUrl}/api/tags`); + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.status}`); + } + + const data = await response.json() as { models: OllamaModel[] }; + this.availableModels = data.models; + + logger.info("Ollama available models:", { + count: this.availableModels.length, + models: this.availableModels.map(m => m.name), + timestamp: new Date().toISOString() + }); + } catch (error) { + logger.error("Failed to fetch Ollama models:", { + error: error instanceof Error ? error.message : String(error), + serverUrl: this.serverUrl, + timestamp: new Date().toISOString() + }); + throw error; + } + } + + private async testModel(modelId: string): Promise { + try { + const testRequest = { + model: modelId, + prompt: "Debug Mode: Test initialization. Respond with 'Initialization successful' if you can read this.", + stream: false, + options: { + temperature: 0.7, + num_predict: 100 + } + }; + + logger.info(`Testing model ${modelId}...`); + + const response = await fetch(`${this.serverUrl}/api/generate`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(testRequest) + }); + + if (!response.ok) { + throw new Error(`Model test failed with status: ${response.status}`); + } + + const result = await response.json() as OllamaResponse; + + if (!result.response) { + throw new Error("No valid response content received"); + } + + logger.info(`Model ${modelId} test response:`, { + content: result.response, + model: result.model, + timestamp: new Date().toISOString() + }); + + return true; + } catch (error) { + logger.error(`Model ${modelId} test failed:`, { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + timestamp: new Date().toISOString() + }); + return false; + } + } + + private async testTextModels(): Promise { + logger.info("Testing configured text models..."); + + const results = await Promise.all([ + this.testModel(this.configuredModels.small), + this.testModel(this.configuredModels.medium) + ]); + + const [smallWorking, mediumWorking] = results; + + if (!smallWorking || !mediumWorking) { + const failedModels = []; + if (!smallWorking) failedModels.push("small"); + if (!mediumWorking) failedModels.push("medium"); + + logger.warn("Some models failed the test:", { + failedModels, + small: this.configuredModels.small, + medium: this.configuredModels.medium + }); + } else { + logger.success("All configured models passed the test"); + } + } + + public async initialize(): Promise { + try { + if (this.initialized) { + logger.info("Ollama already initialized, skipping initialization"); + return; + } + + logger.info("Starting Ollama initialization..."); + const serverAvailable = await this.checkServerStatus(); + + if (!serverAvailable) { + throw new Error("Ollama server is not available"); + } + + await this.fetchAvailableModels(); + await this.testTextModels(); + + this.initialized = true; + logger.success("Ollama initialization complete"); + } catch (error) { + logger.error("Ollama initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? 
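The `testModel` request above is close to the smallest useful `/api/generate` payload: a model name, a prompt, and `stream: false` so the reply arrives as one JSON object rather than newline-delimited chunks. Reduced to a helper (sketch):

```typescript
// One-shot, non-streaming Ollama completion mirroring the testModel request shape.
async function ollamaGenerate(
  model: string,
  prompt: string,
  serverUrl = "http://localhost:11434"
): Promise<string> {
  const res = await fetch(`${serverUrl}/api/generate`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model, prompt, stream: false }),
  });
  if (!res.ok) throw new Error(`Ollama request failed: ${res.status}`);
  const data = (await res.json()) as { response: string };
  return data.response;
}
```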
error.stack : undefined + }); + throw error; + } + } + + public getAvailableModels(): OllamaModel[] { + return this.availableModels; + } + + public isInitialized(): boolean { + return this.initialized; + } + + public async generateText(params: GenerateTextParams, isInitialized = false): Promise { + try { + // Log entry point with all parameters + logger.info("Ollama generateText entry:", { + isInitialized, + currentInitState: this.initialized, + managerInitState: this.isInitialized(), + modelClass: params.modelClass, + contextLength: params.context?.length, + timestamp: new Date().toISOString() + }); + + // Only initialize if not already initialized and not marked as initialized + if (!this.initialized && !isInitialized) { + throw new Error("Ollama not initialized. Please initialize before generating text."); + } + + logger.info("Ollama preparing request:", { + model: params.modelClass === ModelClass.TEXT_LARGE ? + this.configuredModels.medium : + this.configuredModels.small, + contextLength: params.context.length, + timestamp: new Date().toISOString() + }); + + const request = { + model: params.modelClass === ModelClass.TEXT_LARGE ? + this.configuredModels.medium : + this.configuredModels.small, + prompt: params.context, + stream: false, + options: { + temperature: 0.7, + top_p: 0.9, + num_predict: 8192, + repeat_penalty: 1.2, + frequency_penalty: 0.7, + presence_penalty: 0.7 + } + }; + + const response = await fetch(`${this.serverUrl}/api/generate`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(request) + }); + + if (!response.ok) { + throw new Error(`Ollama request failed: ${response.status}`); + } + + const result = await response.json() as OllamaResponse; + + if (!result.response) { + throw new Error("No valid response content received from Ollama"); + } + + let responseText = result.response; + + // Log raw response for debugging + logger.info("Raw response structure:", { + responseLength: responseText.length, + hasAction: responseText.includes("action"), + hasThinkTag: responseText.includes("<think>") + }); + + // Clean think tags if present + if (responseText.includes("<think>")) { + logger.info("Cleaning think tags from response"); + responseText = responseText.replace(/<think>[\s\S]*?<\/think>\n?/g, ""); + logger.info("Think tags removed from response"); + } + + logger.info("Ollama request completed successfully:", { + responseLength: responseText.length, + hasThinkTags: responseText.includes("<think>"), + timestamp: new Date().toISOString() + }); + + return responseText; + } catch (error) { + logger.error("Ollama text generation error:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ?
error.stack : undefined, + phase: "text generation", + timestamp: new Date().toISOString() + }); + throw error; + } + } +} \ No newline at end of file diff --git a/packages/plugin-local-ai/src/utils/studiolmManager.ts b/packages/plugin-local-ai/src/utils/studiolmManager.ts new file mode 100644 index 00000000000..73aa505d466 --- /dev/null +++ b/packages/plugin-local-ai/src/utils/studiolmManager.ts @@ -0,0 +1,311 @@ +import { logger, ModelClass, type GenerateTextParams } from "@elizaos/core"; +import fetch from "node-fetch"; + +interface StudioLMModel { + id: string; + object: string; + created: number; + owned_by: string; +} + +interface ChatMessage { + role: "system" | "user" | "assistant"; + content: string; +} + +interface ChatCompletionRequest { + model: string; + messages: ChatMessage[]; + temperature?: number; + max_tokens?: number; + stream?: boolean; +} + +interface ChatCompletionResponse { + id: string; + object: string; + created: number; + model: string; + choices: { + index: number; + message: ChatMessage; + finish_reason: string; + }[]; +} + +export class StudioLMManager { + private static instance: StudioLMManager | null = null; + private serverUrl: string; + private initialized = false; + private availableModels: StudioLMModel[] = []; + private configuredModels = { + small: "lmstudio-community/deepseek-r1-distill-qwen-1.5b", + medium: "deepseek-r1-distill-qwen-7b" + }; + + private constructor() { + this.serverUrl = process.env.STUDIOLM_SERVER_URL || "http://localhost:1234"; + logger.info("StudioLMManager initialized with configuration:", { + serverUrl: this.serverUrl, + configuredModels: this.configuredModels, + timestamp: new Date().toISOString() + }); + } + + public static getInstance(): StudioLMManager { + if (!StudioLMManager.instance) { + StudioLMManager.instance = new StudioLMManager(); + } + return StudioLMManager.instance; + } + + private async checkServerStatus(): Promise { + try { + const response = await fetch(`${this.serverUrl}/v1/models`); + if (!response.ok) { + throw new Error(`Server responded with status: ${response.status}`); + } + return true; + } catch (error) { + logger.error("LM Studio server check failed:", { + error: error instanceof Error ? error.message : String(error), + serverUrl: this.serverUrl, + timestamp: new Date().toISOString() + }); + return false; + } + } + + private async fetchAvailableModels(): Promise { + try { + const response = await fetch(`${this.serverUrl}/v1/models`); + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.status}`); + } + + const data = await response.json() as { data: StudioLMModel[] }; + this.availableModels = data.data; + + logger.info("LM Studio available models:", { + count: this.availableModels.length, + models: this.availableModels.map(m => m.id), + timestamp: new Date().toISOString() + }); + } catch (error) { + logger.error("Failed to fetch LM Studio models:", { + error: error instanceof Error ? error.message : String(error), + serverUrl: this.serverUrl, + timestamp: new Date().toISOString() + }); + throw error; + } + } + + private async testModel(modelId: string): Promise { + try { + const testRequest: ChatCompletionRequest = { + model: modelId, + messages: [ + { role: "system", content: "Always answer in rhymes. Today is Thursday" }, + { role: "user", content: "What day is it today?" 
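Unlike Ollama's bespoke `/api/generate`, LM Studio speaks the OpenAI chat-completions dialect defined by the interfaces above: `POST /v1/chat/completions` with a `messages` array, the reply nested under `choices[0].message.content`. A minimal client for it (sketch):

```typescript
interface ChatMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

// Non-streaming chat completion against LM Studio's OpenAI-compatible endpoint.
async function lmStudioChat(
  model: string,
  messages: ChatMessage[],
  serverUrl = "http://localhost:1234"
): Promise<string> {
  const res = await fetch(`${serverUrl}/v1/chat/completions`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ model, messages, temperature: 0.7, stream: false }),
  });
  if (!res.ok) throw new Error(`StudioLM request failed: ${res.status}`);
  const data = (await res.json()) as { choices: { message: ChatMessage }[] };
  return data.choices[0]?.message?.content ?? "";
}
```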
} + ], + temperature: 0.7, + max_tokens: -1, + stream: false + }; + + logger.info(`Testing model ${modelId}...`); + + const response = await fetch(`${this.serverUrl}/v1/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(testRequest) + }); + + if (!response.ok) { + throw new Error(`Model test failed with status: ${response.status}`); + } + + const result = await response.json() as ChatCompletionResponse; + + if (!result.choices?.[0]?.message?.content) { + throw new Error("No valid response content received"); + } + + logger.info(`Model ${modelId} test response:`, { + content: result.choices[0].message.content, + model: result.model, + timestamp: new Date().toISOString() + }); + + return true; + } catch (error) { + logger.error(`Model ${modelId} test failed:`, { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined, + timestamp: new Date().toISOString() + }); + return false; + } + } + + private async testTextModels(): Promise { + logger.info("Testing configured text models..."); + + const results = await Promise.all([ + this.testModel(this.configuredModels.small), + this.testModel(this.configuredModels.medium) + ]); + + const [smallWorking, mediumWorking] = results; + + if (!smallWorking || !mediumWorking) { + const failedModels = []; + if (!smallWorking) failedModels.push("small"); + if (!mediumWorking) failedModels.push("medium"); + + logger.warn("Some models failed the test:", { + failedModels, + small: this.configuredModels.small, + medium: this.configuredModels.medium + }); + } else { + logger.success("All configured models passed the test"); + } + } + + public async initialize(): Promise { + try { + if (this.initialized) { + logger.info("StudioLM already initialized, skipping initialization"); + return; + } + + logger.info("Starting StudioLM initialization..."); + const serverAvailable = await this.checkServerStatus(); + + if (!serverAvailable) { + throw new Error("LM Studio server is not available"); + } + + await this.fetchAvailableModels(); + await this.testTextModels(); + + this.initialized = true; + logger.success("StudioLM initialization complete"); + } catch (error) { + logger.error("StudioLM initialization failed:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ? error.stack : undefined + }); + throw error; + } + } + + public getAvailableModels(): StudioLMModel[] { + return this.availableModels; + } + + public isInitialized(): boolean { + return this.initialized; + } + + public async generateText(params: GenerateTextParams, isInitialized = false): Promise { + try { + // Log entry point with all parameters + logger.info("StudioLM generateText entry:", { + isInitialized, + currentInitState: this.initialized, + managerInitState: this.isInitialized(), + modelClass: params.modelClass, + contextLength: params.context?.length, + timestamp: new Date().toISOString() + }); + + // Only initialize if not already initialized and not marked as initialized + if (!this.initialized && !isInitialized) { + throw new Error("StudioLM not initialized. Please initialize before generating text."); + } + + const messages: ChatMessage[] = [ + { role: "system", content: "You are a helpful AI assistant. Respond to the current request only." }, + { role: "user", content: params.context } + ]; + + logger.info("StudioLM preparing request:", { + model: params.modelClass === ModelClass.TEXT_LARGE ? 
+ this.configuredModels.medium : + this.configuredModels.small, + messageCount: messages.length, + timestamp: new Date().toISOString() + }); + + logger.info("Incoming context structure:", { + contextLength: params.context.length, + hasAction: params.context.includes("action"), + runtime: !!params.runtime, + stopSequences: params.stopSequences + }); + + const request: ChatCompletionRequest = { + model: params.modelClass === ModelClass.TEXT_LARGE ? this.configuredModels.medium : this.configuredModels.small, + messages, + temperature: 0.7, + max_tokens: 8192, + stream: false + }; + + const response = await fetch(`${this.serverUrl}/v1/chat/completions`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify(request) + }); + + if (!response.ok) { + throw new Error(`StudioLM request failed: ${response.status}`); + } + + const result = await response.json() as ChatCompletionResponse; + + if (!result.choices?.[0]?.message?.content) { + throw new Error("No valid response content received from StudioLM"); + } + + let responseText = result.choices[0].message.content; + + // Log raw response for debugging + logger.info("Raw response structure:", { + responseLength: responseText.length, + hasAction: responseText.includes("action"), + hasThinkTag: responseText.includes("<think>") + }); + + // Clean think tags if present + if (responseText.includes("<think>")) { + logger.info("Cleaning think tags from response"); + responseText = responseText.replace(/<think>[\s\S]*?<\/think>\n?/g, ""); + logger.info("Think tags removed from response"); + } + + logger.info("StudioLM request completed successfully:", { + responseLength: responseText.length, + hasThinkTags: responseText.includes("<think>"), + timestamp: new Date().toISOString() + }); + + return responseText; + } catch (error) { + logger.error("StudioLM text generation error:", { + error: error instanceof Error ? error.message : String(error), + stack: error instanceof Error ?
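The cleanup step above exists because DeepSeek-R1 models emit their chain of thought wrapped in `<think>…</think>` before the final answer; both managers strip it so the reasoning never leaks into a chat reply. Isolated as a helper (sketch):

```typescript
// Remove DeepSeek-R1 style <think>…</think> reasoning blocks from a reply.
function stripThinkTags(text: string): string {
  if (!text.includes("<think>")) return text;
  return text.replace(/<think>[\s\S]*?<\/think>\n?/g, "");
}
```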
error.stack : undefined, + phase: "text generation", + timestamp: new Date().toISOString() + }); + throw error; + } + } +} \ No newline at end of file From 3bc73e5d5cfb008b71c50e3ec79a59ab1b1b126a Mon Sep 17 00:00:00 2001 From: AIFlow_ML Date: Tue, 25 Feb 2025 16:07:29 +0700 Subject: [PATCH 06/13] Added .env configuration for Ollama and StudioLM --- .../__tests__/text-summ.test.ts | 0 packages/plugin-local-ai/environment.ts | 36 ------------------ packages/plugin-local-ai/src/environment.ts | 37 ++++++++++++++++++- .../src/utils/ollamaManager.ts | 5 ++- .../src/utils/studiolmManager.ts | 4 +- 5 files changed, 40 insertions(+), 42 deletions(-) delete mode 100644 packages/plugin-local-ai/__tests__/text-summ.test.ts delete mode 100644 packages/plugin-local-ai/environment.ts diff --git a/packages/plugin-local-ai/__tests__/text-summ.test.ts b/packages/plugin-local-ai/__tests__/text-summ.test.ts deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/packages/plugin-local-ai/environment.ts b/packages/plugin-local-ai/environment.ts deleted file mode 100644 index c218e7e1ae6..00000000000 --- a/packages/plugin-local-ai/environment.ts +++ /dev/null @@ -1,36 +0,0 @@ -import type { IAgentRuntime } from "@elizaos/core"; -import { z } from "zod"; - -export const nodeEnvSchema = z.object({ - VITS_VOICE: z.string().optional(), - VITS_MODEL: z.string().optional(), -}); - -export type NodeConfig = z.infer; - -export async function validateNodeConfig( - runtime: IAgentRuntime -): Promise { - try { - const voiceSettings = runtime.character.settings?.voice; - - // Only include what's absolutely required - const config = { - // VITS settings - VITS_VOICE: voiceSettings?.model || process.env.VITS_VOICE, - VITS_MODEL: process.env.VITS_MODEL, - }; - - return nodeEnvSchema.parse(config); - } catch (error) { - if (error instanceof z.ZodError) { - const errorMessages = error.errors - .map((err) => `${err.path.join(".")}: ${err.message}`) - .join("\n"); - throw new Error( - `Node configuration validation failed:\n${errorMessages}` - ); - } - throw error; - } -} diff --git a/packages/plugin-local-ai/src/environment.ts b/packages/plugin-local-ai/src/environment.ts index eaec9981059..8e879f52e65 100644 --- a/packages/plugin-local-ai/src/environment.ts +++ b/packages/plugin-local-ai/src/environment.ts @@ -7,6 +7,21 @@ export const configSchema = z.object({ USE_LOCAL_AI: z.boolean().default(true), USE_STUDIOLM_TEXT_MODELS: z.boolean().default(false), USE_OLLAMA_TEXT_MODELS: z.boolean().default(false), + + // Ollama Configuration + OLLAMA_SERVER_URL: z.string().default("http://localhost:11434"), + OLLAMA_MODEL: z.string().default("deepseek-r1-distill-qwen-7b"), + USE_OLLAMA_EMBEDDING: z.boolean().default(false), + OLLAMA_EMBEDDING_MODEL: z.string().default(""), + SMALL_OLLAMA_MODEL: z.string().default("deepseek-r1:1.5b"), + MEDIUM_OLLAMA_MODEL: z.string().default("deepseek-r1:7b"), + LARGE_OLLAMA_MODEL: z.string().default("deepseek-r1:7b"), + + // StudioLM Configuration + STUDIOLM_SERVER_URL: z.string().default("http://localhost:1234"), + STUDIOLM_SMALL_MODEL: z.string().default("lmstudio-community/deepseek-r1-distill-qwen-1.5b"), + STUDIOLM_MEDIUM_MODEL: z.string().default("deepseek-r1-distill-qwen-7b"), + STUDIOLM_EMBEDDING_MODEL: z.union([z.boolean(), z.string()]).default(false), }); export type Config = z.infer; @@ -41,7 +56,9 @@ export async function validateConfig( logger.info("Raw environment variables:", { USE_LOCAL_AI: process.env.USE_LOCAL_AI, USE_STUDIOLM_TEXT_MODELS: 
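Because every new field in the schema above carries a `.default(...)`, parsing even an empty object produces a fully-populated config, so downstream code never sees an `undefined` server URL or model name. For example (sketch):

```typescript
import { z } from "zod";

const ollamaSchema = z.object({
  OLLAMA_SERVER_URL: z.string().default("http://localhost:11434"),
  SMALL_OLLAMA_MODEL: z.string().default("deepseek-r1:1.5b"),
  MEDIUM_OLLAMA_MODEL: z.string().default("deepseek-r1:7b"),
});

const cfg = ollamaSchema.parse({}); // all defaults applied
console.log(cfg.OLLAMA_SERVER_URL); // -> "http://localhost:11434"
```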
process.env.USE_STUDIOLM_TEXT_MODELS, - USE_OLLAMA_TEXT_MODELS: process.env.USE_OLLAMA_TEXT_MODELS + USE_OLLAMA_TEXT_MODELS: process.env.USE_OLLAMA_TEXT_MODELS, + OLLAMA_SERVER_URL: process.env.OLLAMA_SERVER_URL, + STUDIOLM_SERVER_URL: process.env.STUDIOLM_SERVER_URL }); // Parse environment variables with proper boolean conversion @@ -49,6 +66,7 @@ export async function validateConfig( USE_LOCAL_AI: true, // Always true USE_STUDIOLM_TEXT_MODELS: config.USE_STUDIOLM_TEXT_MODELS === 'true', USE_OLLAMA_TEXT_MODELS: config.USE_OLLAMA_TEXT_MODELS === 'true', + USE_OLLAMA_EMBEDDING: config.USE_OLLAMA_EMBEDDING === 'true', }; logger.info("Parsed boolean configuration:", booleanConfig); @@ -56,7 +74,22 @@ export async function validateConfig( // Validate text model source configuration validateModelConfig(booleanConfig); - const validatedConfig = configSchema.parse(booleanConfig); + // Create full config with all values + const fullConfig = { + ...booleanConfig, + OLLAMA_SERVER_URL: config.OLLAMA_SERVER_URL || "http://localhost:11434", + OLLAMA_MODEL: config.OLLAMA_MODEL || "deepseek-r1-distill-qwen-7b", + OLLAMA_EMBEDDING_MODEL: config.OLLAMA_EMBEDDING_MODEL || "", + SMALL_OLLAMA_MODEL: config.SMALL_OLLAMA_MODEL || "deepseek-r1:1.5b", + MEDIUM_OLLAMA_MODEL: config.MEDIUM_OLLAMA_MODEL || "deepseek-r1:7b", + LARGE_OLLAMA_MODEL: config.LARGE_OLLAMA_MODEL || "deepseek-r1:7b", + STUDIOLM_SERVER_URL: config.STUDIOLM_SERVER_URL || "http://localhost:1234", + STUDIOLM_SMALL_MODEL: config.STUDIOLM_SMALL_MODEL || "lmstudio-community/deepseek-r1-distill-qwen-1.5b", + STUDIOLM_MEDIUM_MODEL: config.STUDIOLM_MEDIUM_MODEL || "deepseek-r1-distill-qwen-7b", + STUDIOLM_EMBEDDING_MODEL: config.STUDIOLM_EMBEDDING_MODEL || false, + }; + + const validatedConfig = configSchema.parse(fullConfig); logger.info("Final validated configuration:", validatedConfig); diff --git a/packages/plugin-local-ai/src/utils/ollamaManager.ts b/packages/plugin-local-ai/src/utils/ollamaManager.ts index 8312b6900eb..03cb9b7f5ca 100644 --- a/packages/plugin-local-ai/src/utils/ollamaManager.ts +++ b/packages/plugin-local-ai/src/utils/ollamaManager.ts @@ -1,4 +1,5 @@ import { logger, ModelClass, type GenerateTextParams } from "@elizaos/core"; + import fetch from "node-fetch"; interface OllamaModel { @@ -25,8 +26,8 @@ export class OllamaManager { private initialized = false; private availableModels: OllamaModel[] = []; private configuredModels = { - small: "deepseek-r1:1.5b", - medium: "deepseek-r1:7b" + small: process.env.SMALL_OLLAMA_MODEL || "deepseek-r1:1.5b", + medium: process.env.MEDIUM_OLLAMA_MODEL || "deepseek-r1:7b" }; private constructor() { diff --git a/packages/plugin-local-ai/src/utils/studiolmManager.ts b/packages/plugin-local-ai/src/utils/studiolmManager.ts index 73aa505d466..145960e4fe0 100644 --- a/packages/plugin-local-ai/src/utils/studiolmManager.ts +++ b/packages/plugin-local-ai/src/utils/studiolmManager.ts @@ -39,8 +39,8 @@ export class StudioLMManager { private initialized = false; private availableModels: StudioLMModel[] = []; private configuredModels = { - small: "lmstudio-community/deepseek-r1-distill-qwen-1.5b", - medium: "deepseek-r1-distill-qwen-7b" + small: process.env.STUDIOLM_SMALL_MODEL || "lmstudio-community/deepseek-r1-distill-qwen-1.5b", + medium: process.env.STUDIOLM_MEDIUM_MODEL || "deepseek-r1-distill-qwen-7b" }; private constructor() { From ff335976d2201ccc62a0c93c83c8b441005376fd Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 16:54:27 +0530 Subject: [PATCH 07/13] remove default values from env 
and added readme --- .env.example | 47 ++++----- packages/plugin-local-ai/README.md | 153 +++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 26 deletions(-) diff --git a/.env.example b/.env.example index 9f8f970986c..8808b3256ea 100644 --- a/.env.example +++ b/.env.example @@ -30,37 +30,32 @@ SOLANA_PRIVATE_KEY= BIRDEYE_API_KEY= # Swarm settings -COMPLIANCE_OFFICER_DISCORD_APPLICATION_ID= -COMPLIANCE_OFFICER_DISCORD_API_TOKEN= +COMMUNITY_MANAGER_DISCORD_APPLICATION_ID= +COMMUNITY_MANAGER_DISCORD_API_TOKEN= + +SOCIAL_MEDIA_MANAGER_DISCORD_APPLICATION_ID= +SOCIAL_MEDIA_MANAGER_DISCORD_API_TOKEN= + +COUNSELOR_DISCORD_APPLICATION_ID= +COUNSELOR_DISCORD_API_TOKEN= -# ================================== # Local AI Configuration -# ================================== -USE_LOCAL_AI=true -USE_STUDIOLM_TEXT_MODELS=false -USE_OLLAMA_TEXT_MODELS=false +USE_LOCAL_AI= +USE_STUDIOLM_TEXT_MODELS= +USE_OLLAMA_TEXT_MODELS= # Ollama Configuration -OLLAMA_SERVER_URL=http://localhost:11434 -OLLAMA_MODEL=deepseek-r1-distill-qwen-7b -USE_OLLAMA_EMBEDDING=false +OLLAMA_SERVER_URL= +OLLAMA_MODEL= +USE_OLLAMA_EMBEDDING= OLLAMA_EMBEDDING_MODEL= -SMALL_OLLAMA_MODEL=deepseek-r1:1.5b -MEDIUM_OLLAMA_MODEL=deepseek-r1:7b -LARGE_OLLAMA_MODEL=deepseek-r1:7b +SMALL_OLLAMA_MODEL= +MEDIUM_OLLAMA_MODEL= +LARGE_OLLAMA_MODEL= # StudioLM Configuration -STUDIOLM_SERVER_URL=http://localhost:1234 -STUDIOLM_SMALL_MODEL=lmstudio-community/deepseek-r1-distill-qwen-1.5b -STUDIOLM_MEDIUM_MODEL=deepseek-r1-distill-qwen-7b -STUDIOLM_EMBEDDING_MODEL=false - -COMMUNITY_MANAGER_DISCORD_APPLICATION_ID= -COMMUNITY_MANAGER_DISCORD_API_TOKEN= - -SOCIAL_MEDIA_MANAGER_DISCORD_APPLICATION_ID= -SOCIAL_MEDIA_MANAGER_DISCORD_API_TOKEN= - -COUNSELOR_DISCORD_APPLICATION_ID= -COUNSELOR_DISCORD_API_TOKEN= \ No newline at end of file +STUDIOLM_SERVER_URL= +STUDIOLM_SMALL_MODEL= +STUDIOLM_MEDIUM_MODEL= +STUDIOLM_EMBEDDING_MODEL= diff --git a/packages/plugin-local-ai/README.md b/packages/plugin-local-ai/README.md index e69de29bb2d..e9bcc026336 100644 --- a/packages/plugin-local-ai/README.md +++ b/packages/plugin-local-ai/README.md @@ -0,0 +1,153 @@ +# Local AI Plugin + +This plugin provides local AI model capabilities through the ElizaOS platform, supporting text generation, image analysis, speech synthesis, and audio transcription. 
+ +## Usage + +Add the plugin to your character configuration: + +```json +"plugins": ["@elizaos/plugin-local-ai"] +``` + +## Configuration + +The plugin requires these environment variables (can be set in .env file or character settings): + +```json +"settings": { + "USE_LOCAL_AI": true, + "USE_STUDIOLM_TEXT_MODELS": false, + "USE_OLLAMA_TEXT_MODELS": false, + + "OLLAMA_SERVER_URL": "http://localhost:11434", + "OLLAMA_MODEL": "deepseek-r1-distill-qwen-7b", + "USE_OLLAMA_EMBEDDING": false, + "OLLAMA_EMBEDDING_MODEL": "", + "SMALL_OLLAMA_MODEL": "deepseek-r1:1.5b", + "MEDIUM_OLLAMA_MODEL": "deepseek-r1:7b", + "LARGE_OLLAMA_MODEL": "deepseek-r1:7b", + + "STUDIOLM_SERVER_URL": "http://localhost:1234", + "STUDIOLM_SMALL_MODEL": "lmstudio-community/deepseek-r1-distill-qwen-1.5b", + "STUDIOLM_MEDIUM_MODEL": "deepseek-r1-distill-qwen-7b", + "STUDIOLM_EMBEDDING_MODEL": false +} +``` + +Or in `.env` file: +```env +# Local AI Configuration +USE_LOCAL_AI=true +USE_STUDIOLM_TEXT_MODELS=false +USE_OLLAMA_TEXT_MODELS=false + +# Ollama Configuration +OLLAMA_SERVER_URL=http://localhost:11434 +OLLAMA_MODEL=deepseek-r1-distill-qwen-7b +USE_OLLAMA_EMBEDDING=false +OLLAMA_EMBEDDING_MODEL= +SMALL_OLLAMA_MODEL=deepseek-r1:1.5b +MEDIUM_OLLAMA_MODEL=deepseek-r1:7b +LARGE_OLLAMA_MODEL=deepseek-r1:7b + +# StudioLM Configuration +STUDIOLM_SERVER_URL=http://localhost:1234 +STUDIOLM_SMALL_MODEL=lmstudio-community/deepseek-r1-distill-qwen-1.5b +STUDIOLM_MEDIUM_MODEL=deepseek-r1-distill-qwen-7b +STUDIOLM_EMBEDDING_MODEL=false +``` + +### Configuration Options + +#### Text Model Source (Choose One) +- `USE_STUDIOLM_TEXT_MODELS`: Enable StudioLM text models +- `USE_OLLAMA_TEXT_MODELS`: Enable Ollama text models +Note: Only one text model source can be enabled at a time + +#### Ollama Settings +- `OLLAMA_SERVER_URL`: Ollama API endpoint (default: http://localhost:11434) +- `OLLAMA_MODEL`: Default model for general use +- `USE_OLLAMA_EMBEDDING`: Enable Ollama for embeddings +- `OLLAMA_EMBEDDING_MODEL`: Model for embeddings when enabled +- `SMALL_OLLAMA_MODEL`: Model for lighter tasks +- `MEDIUM_OLLAMA_MODEL`: Model for standard tasks +- `LARGE_OLLAMA_MODEL`: Model for complex tasks + +#### StudioLM Settings +- `STUDIOLM_SERVER_URL`: StudioLM API endpoint (default: http://localhost:1234) +- `STUDIOLM_SMALL_MODEL`: Model for lighter tasks +- `STUDIOLM_MEDIUM_MODEL`: Model for standard tasks +- `STUDIOLM_EMBEDDING_MODEL`: Model for embeddings (or false to disable) + +## Features + +The plugin provides these model classes: +- `TEXT_SMALL`: Fast, efficient text generation using smaller models +- `TEXT_LARGE`: More capable text generation using larger models +- `IMAGE_DESCRIPTION`: Local image analysis using Florence-2 vision model +- `TEXT_TO_SPEECH`: Local text-to-speech synthesis +- `TRANSCRIPTION`: Local audio transcription using Whisper + +### Image Analysis +```typescript +const { title, description } = await runtime.useModel( + ModelClass.IMAGE_DESCRIPTION, + "https://example.com/image.jpg" +); +``` + +### Text-to-Speech +```typescript +const audioStream = await runtime.useModel( + ModelClass.TEXT_TO_SPEECH, + "Text to convert to speech" +); +``` + +### Audio Transcription +```typescript +const transcription = await runtime.useModel( + ModelClass.TRANSCRIPTION, + audioBuffer +); +``` + +### Text Generation +```typescript +// Using small model +const smallResponse = await runtime.useModel( + ModelClass.TEXT_SMALL, + { + context: "Generate a short response", + stopSequences: [] + } +); + +// Using large model +const 
largeResponse = await runtime.useModel( + ModelClass.TEXT_LARGE, + { + context: "Generate a detailed response", + stopSequences: [] + } +); +``` + +## Model Sources + +### 1. StudioLM (LM Studio) +- Local inference server for running various open models +- Supports chat completion API similar to OpenAI +- Configure with `USE_STUDIOLM_TEXT_MODELS=true` +- Supports both small and medium-sized models +- Optional embedding model support + +### 2. Ollama +- Local model server with optimized inference +- Supports various open models in GGUF format +- Configure with `USE_OLLAMA_TEXT_MODELS=true` +- Supports small, medium, and large models +- Optional embedding model support + +Note: The plugin validates that only one text model source is enabled at a time to prevent conflicts. \ No newline at end of file From 7bdb2d9ede79885386f19a10fd00e039c3bd7943 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 16:56:10 +0530 Subject: [PATCH 08/13] Update index.ts --- packages/plugin-local-ai/src/index.ts | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/packages/plugin-local-ai/src/index.ts b/packages/plugin-local-ai/src/index.ts index 7b79075b8eb..4164fcaa7f7 100644 --- a/packages/plugin-local-ai/src/index.ts +++ b/packages/plugin-local-ai/src/index.ts @@ -1241,24 +1241,6 @@ export const localAIPlugin: Plugin = { ] } ], - routes: [ - { - path: "/health", - type: "GET", - handler: async (_req: unknown, res: { json: (data: unknown) => void }) => { - res.json({ - status: "healthy", - models: { - small: true, - large: true, - vision: true, - transcription: true, - tts: true - } - }); - } - } - ] }; export default localAIPlugin; From 72a8ce6d22aa02fc49af2c42829269cd5097f57a Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 17:00:19 +0530 Subject: [PATCH 09/13] cleaner --- packages/plugin-local-ai/src/index.ts | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/packages/plugin-local-ai/src/index.ts b/packages/plugin-local-ai/src/index.ts index 4164fcaa7f7..ffc033f839b 100644 --- a/packages/plugin-local-ai/src/index.ts +++ b/packages/plugin-local-ai/src/index.ts @@ -1,9 +1,6 @@ -import { type IAgentRuntime, logger, ModelClass, type Plugin } from "@elizaos/core"; import type { GenerateTextParams } from "@elizaos/core"; -import { exec } from "node:child_process"; -import * as Echogarden from "echogarden"; +import { type IAgentRuntime, logger, ModelClass, type Plugin } from "@elizaos/core"; import { EmbeddingModel, FlagEmbedding } from "fastembed"; -import fs from "node:fs"; import { getLlama, type Llama, @@ -12,24 +9,20 @@ import { type LlamaContextSequence, type LlamaModel } from "node-llama-cpp"; -// import { nodewhisper } from "nodejs-whisper"; -// import os from "node:os"; +import fs from "node:fs"; import path from "node:path"; import { Readable } from "node:stream"; import { fileURLToPath } from "node:url"; -// import { promisify } from "node:util"; -import { z } from "zod"; -// import https from "node:https"; -import { getPlatformManager } from "./utils/platform"; -import { TokenizerManager } from './utils/tokenizerManager'; +import { validateConfig } from "./environment"; import { MODEL_SPECS, type ModelSpec } from './types'; import { DownloadManager } from './utils/downloadManager'; -import { VisionManager } from './utils/visionManager'; +import { OllamaManager } from './utils/ollamaManager'; +import { getPlatformManager } from "./utils/platform"; +import { StudioLMManager } from './utils/studiolmManager'; +import { TokenizerManager 
} from './utils/tokenizerManager'; import { TranscribeManager } from './utils/transcribeManager'; import { TTSManager } from './utils/ttsManager'; -import { StudioLMManager } from './utils/studiolmManager'; -import { OllamaManager } from './utils/ollamaManager'; -import { validateConfig } from "./environment"; +import { VisionManager } from './utils/visionManager'; // const execAsync = promisify(exec); const __filename = fileURLToPath(import.meta.url); From 045d3262c55368730de72d928769182d3b77a7c4 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 17:51:18 +0530 Subject: [PATCH 10/13] clean up server --- packages/agent/src/index.ts | 35 +- packages/agent/src/server/index.ts | 1028 +++-------------- .../src/utils/visionManager.ts | 2 +- 3 files changed, 190 insertions(+), 875 deletions(-) diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index b2be333b04c..404da9ec305 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -29,7 +29,8 @@ import { defaultCharacter } from "./single-agent/character.ts"; import { startScenario } from "./swarm/scenario.ts"; import swarm from "./swarm/index"; - +import * as path from "node:path"; +import * as fs from "node:fs"; export const wait = (minTime = 1000, maxTime = 3000) => { const waitTime = @@ -80,8 +81,6 @@ export function parseArguments(): { } } - - export async function createAgent( character: Character ): Promise { @@ -224,14 +223,28 @@ const checkPortAvailable = (port: number): Promise => { const startAgents = async () => { const server = new AgentServer(); + + // Assign the required functions first + server.startAgent = async (character) => { + logger.info(`Starting agent for character ${character.name}`); + return startAgent(character, server); + }; + server.loadCharacterTryPath = loadCharacterTryPath; + server.jsonToCharacter = jsonToCharacter; + let serverPort = Number.parseInt(settings.SERVER_PORT || "3000"); const args = parseArguments(); const charactersArg = args.characters || args.character; - let characters = []; + const characters: Character[] = []; - // Assign the character loading functions - server.loadCharacterTryPath = loadCharacterTryPath; - server.jsonToCharacter = jsonToCharacter; + // Add this before creating the AgentServer + const dataDir = path.join(process.cwd(), "data"); + try { + fs.accessSync(dataDir, fs.constants.W_OK); + logger.debug(`Data directory ${dataDir} is writable`); + } catch (error) { + logger.error(`Data directory ${dataDir} is not writable:`, error); + } if (args.swarm) { try { @@ -275,12 +288,6 @@ const startAgents = async () => { serverPort++; } - server.startAgent = async (character) => { - logger.info(`Starting agent for character ${character.name}`); - return startAgent(character, server); - }; - - server.start(serverPort); if (serverPort !== Number.parseInt(settings.SERVER_PORT || "3000")) { @@ -293,7 +300,7 @@ const startAgents = async () => { }; startAgents().catch((error) => { - logger.error("Unhandled error in startAgents:", error); + logger.error("Unhandled error in startAgents:", error.message); process.exit(1); }); diff --git a/packages/agent/src/server/index.ts b/packages/agent/src/server/index.ts index e980d5167b3..b12cc9dcb9d 100644 --- a/packages/agent/src/server/index.ts +++ b/packages/agent/src/server/index.ts @@ -1,36 +1,24 @@ import { - ChannelType, - composeContext, - generateMessageResponse, - generateObject, logger, - ModelClass, - stringToUuid, type Character, - type Content, type IAgentRuntime, - type Media, - type Memory, + type 
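The `fs.accessSync(dataDir, fs.constants.W_OK)` check added above fails fast when the data directory is not writable, instead of surfacing later as an upload error. A standalone version that also creates the directory first, mirroring what the server constructor later does with `mkdirSync(..., { recursive: true })` (sketch; `ensureWritableDir` is an illustrative name):

```typescript
import * as fs from "node:fs";
import * as path from "node:path";

// Create a directory if needed and verify the process can write to it.
function ensureWritableDir(dir: string): boolean {
  try {
    fs.mkdirSync(dir, { recursive: true });
    fs.accessSync(dir, fs.constants.W_OK);
    return true;
  } catch (error) {
    console.error(`${dir} is not writable:`, error);
    return false;
  }
}

ensureWritableDir(path.join(process.cwd(), "data"));
```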
Provider, + type Route } from "@elizaos/core"; import bodyParser from "body-parser"; import cors from "cors"; -import express, { type Request as ExpressRequest } from "express"; +import express, { type Server } from "express"; import * as fs from "node:fs"; import * as path from "node:path"; -import { z } from "zod"; import { createApiRouter } from "./api/index.ts"; -import { hyperfiHandlerTemplate, messageHandlerTemplate } from "./helper.ts"; import replyAction from "./reply.ts"; -import { upload } from "./loader.ts"; -export interface ServerMiddleware { - ( +export type ServerMiddleware = ( req: express.Request, res: express.Response, next: express.NextFunction - ): void; -} +) => void; export interface ServerOptions { middlewares?: ServerMiddleware[]; @@ -38,873 +26,193 @@ export interface ServerOptions { export class AgentServer { public app: express.Application; - private agents: Map; // container management - private server: any; // Store server instance - public startAgent: (character: Character) => Promise; // Store startAgent function - public loadCharacterTryPath: (characterPath: string) => Promise; // Store loadCharacterTryPath function - public jsonToCharacter: (character: string | never) => Promise; // Store jsonToCharacter function - - constructor() { - logger.log("DirectClient constructor"); - this.app = express(); - this.app.use(cors()); - this.agents = new Map(); - - this.app.use(bodyParser.json()); - this.app.use(bodyParser.urlencoded({ extended: true })); - - if (options?.middlewares) { - for (const middleware of options.middlewares) { - this.app.use(middleware); - } - } - - // Serve both uploads and generated images - this.app.use( - "/media/uploads", - express.static(path.join(process.cwd(), "/data/uploads")) - ); - this.app.use( - "/media/generated", - express.static(path.join(process.cwd(), "/generatedImages")) - ); - - // Serve both uploads and generated images - this.app.use( - "/media/uploads", - express.static(path.join(process.cwd(), "/data/uploads")) - ); - this.app.use( - "/media/generated", - express.static(path.join(process.cwd(), "/generatedImages")) - ); - - const apiRouter = createApiRouter(this.agents, this); - this.app.use(apiRouter); - - // Define an interface that extends the Express Request interface - interface CustomRequest extends ExpressRequest { - file?: Express.Multer.File; - } - - // Update the route handler to use CustomRequest instead of express.Request - this.app.post( - "/:agentId/whisper", - upload.single("file"), - async (req: CustomRequest, res: express.Response) => { - const audioFile = req.file; // Access the uploaded file using req.file - const agentId = req.params.agentId; - - if (!audioFile) { - res.status(400).send("No audio file provided"); - return; - } - - let runtime = this.agents.get(agentId); - - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => a.character.name.toLowerCase() === agentId.toLowerCase() - ); - } - - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - const audioBuffer = fs.readFileSync(audioFile.path); - const transcription = await runtime.useModel( - ModelClass.TRANSCRIPTION, - audioBuffer - ); - - res.json({ text: transcription }); - } - ); - - this.app.post( - "/:agentId/message", - upload.single("file"), - async (req: express.Request, res: express.Response) => { - const agentId = req.params.agentId; - const roomId = stringToUuid( - req.body.roomId ?? 
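The refactor above turns `ServerMiddleware` into a plain type alias for Express's three-argument handler, so anything with that shape can be injected through `ServerOptions.middlewares`. A request logger, for instance (illustrative, not part of the patch):

```typescript
import type express from "express";

type ServerMiddleware = (
  req: express.Request,
  res: express.Response,
  next: express.NextFunction
) => void;

// Log method and path, then hand control to the next handler in the chain.
const requestLogger: ServerMiddleware = (req, _res, next) => {
  console.log(`${req.method} ${req.path}`);
  next();
};

// Usage: new AgentServer({ middlewares: [requestLogger] })
```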
`default-room-${agentId}` - ); - const userId = stringToUuid(req.body.userId ?? "user"); - - let runtime = this.agents.get(agentId); - - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => a.character.name.toLowerCase() === agentId.toLowerCase() - ); - } - - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - await runtime.ensureConnection({ - userId, - roomId, - userName: req.body.userName, - userScreenName: req.body.name, - source: "direct", - type: ChannelType.API, - }); - - const text = req.body.text; - // if empty text, directly return - if (!text) { - res.json([]); - return; - } - - const messageId = stringToUuid(Date.now().toString()); - - const attachments: Media[] = []; - if (req.file) { - const filePath = path.join( - process.cwd(), - "data", - "uploads", - req.file.filename - ); - attachments.push({ - id: Date.now().toString(), - url: filePath, - title: req.file.originalname, - source: "direct", - description: `Uploaded file: ${req.file.originalname}`, - text: "", - contentType: req.file.mimetype, - }); - } - - const content: Content = { - text, - attachments, - source: "direct", - inReplyTo: undefined, - }; - - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const memory: Memory = { - id: stringToUuid(`${messageId}-${userId}`), - ...userMessage, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - await runtime.messageManager.addEmbeddingToMemory(memory); - await runtime.messageManager.createMemory(memory); - - let state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - const context = composeContext({ - state, - template: messageHandlerTemplate, - }); + private agents: Map; + private server: any; // Change back to any since Server type isn't exported + public startAgent!: (character: Character) => Promise; // Add ! 
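The `!` being introduced on these properties is TypeScript's definite-assignment assertion: it suppresses the strict-property-initialization error by promising the compiler the field is set before first use, which `startAgents` does by assigning `server.startAgent` immediately after construction. In miniature (sketch):

```typescript
class Example {
  // Definite-assignment assertion: not set in the constructor,
  // but guaranteed to be assigned before anything calls it.
  public startAgent!: (name: string) => Promise<void>;
}

const example = new Example();
example.startAgent = async (name) => {
  console.log(`starting agent ${name}`);
};
```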
to indicate it will be assigned later + public loadCharacterTryPath!: (characterPath: string) => Promise; + public jsonToCharacter!: (character: unknown) => Promise; - const response = await generateMessageResponse({ - runtime: runtime, - context, - modelClass: ModelClass.TEXT_LARGE, - }); - - if (!response) { - res.status(500).send("No response from generateMessageResponse"); - return; - } - - // save response to memory - const responseMessage: Memory = { - id: stringToUuid(`${messageId}-${runtime.agentId}`), - ...userMessage, - userId: runtime.agentId, - content: response, - createdAt: Date.now(), - }; - - await runtime.messageManager.createMemory(responseMessage); - - state = await runtime.updateRecentMessageState(state); - - const replyHandler = async (message: Content) => { - res.json([message]); - return [memory]; - }; - - await runtime.processActions( - memory, - [responseMessage], - state, - replyHandler - ); - - await runtime.evaluate(memory, state); - } - ); - - this.app.post( - "/agents/:agentIdOrName/hyperfy/v1", - async (req: express.Request, res: express.Response) => { - // get runtime - const agentId = req.params.agentIdOrName; - let runtime = this.agents.get(agentId); - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => a.character.name.toLowerCase() === agentId.toLowerCase() - ); - } - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - // can we be in more than one hyperfy world at once - // but you may want the same context is multiple worlds - // this is more like an instanceId - const roomId = stringToUuid(req.body.roomId ?? "hyperfy"); - - const body = req.body; - - // hyperfy specific parameters - let nearby = []; - let availableEmotes = []; - - if (body.nearby) { - nearby = body.nearby; - } - if (body.messages) { - // loop on the messages and record the memories - // might want to do this in parallel - for (const msg of body.messages) { - const parts = msg.split(/:\s*/); - const mUserId = stringToUuid(parts[0]); - await runtime.ensureConnection({ - userId: mUserId, - roomId, // where - userName: parts[0], // username - userScreenName: parts[0], // userScreeName? - source: "hyperfy", - type: ChannelType.WORLD, - }); - const content: Content = { - text: parts[1] || "", - attachments: [], - source: "hyperfy", - inReplyTo: undefined, - }; - const memory: Memory = { - id: stringToUuid(msg), - agentId: runtime.agentId, - userId: mUserId, - roomId, - content, - }; - await runtime.messageManager.createMemory(memory); - } - } - if (body.availableEmotes) { - availableEmotes = body.availableEmotes; - } - - const content: Content = { - // we need to compose who's near and what emotes are available - text: JSON.stringify(req.body), - attachments: [], - source: "hyperfy", - inReplyTo: undefined, - }; - - const userId = stringToUuid("hyperfy"); - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - let template = hyperfiHandlerTemplate; - template = template.replace("{{emotes}}", availableEmotes.join("|")); - template = template.replace("{{nearby}}", nearby.join("|")); - const context = composeContext({ - state, - template, - }); - - function createHyperfiOutSchema( - nearby: string[], - availableEmotes: string[] - ) { - const lookAtSchema = - nearby.length > 1 - ? 
z - .union( - nearby.map((item) => z.literal(item)) as [ - z.ZodLiteral, - z.ZodLiteral, - ...z.ZodLiteral[] - ] - ) - .nullable() - : nearby.length === 1 - ? z.literal(nearby[0]).nullable() - : z.null(); // Fallback for empty array - - const emoteSchema = - availableEmotes.length > 1 - ? z - .union( - availableEmotes.map((item) => z.literal(item)) as [ - z.ZodLiteral, - z.ZodLiteral, - ...z.ZodLiteral[] - ] - ) - .nullable() - : availableEmotes.length === 1 - ? z.literal(availableEmotes[0]).nullable() - : z.null(); // Fallback for empty array - - return z.object({ - lookAt: lookAtSchema, - emote: emoteSchema, - say: z.string().nullable(), - actions: z.array(z.string()).nullable(), - }); - } + constructor(options?: ServerOptions) { + try { + logger.log("Initializing AgentServer..."); + this.app = express(); + this.agents = new Map(); + + // Core middleware setup + this.app.use(cors()); + this.app.use(bodyParser.json()); + this.app.use(bodyParser.urlencoded({ extended: true })); + + // Custom middleware setup + if (options?.middlewares) { + for (const middleware of options.middlewares) { + this.app.use(middleware); + } + } - // Define the schema for the expected output - const hyperfiOutSchema = createHyperfiOutSchema( - nearby, - availableEmotes - ); + // Static file serving setup + const uploadsPath = path.join(process.cwd(), "/data/uploads"); + const generatedPath = path.join(process.cwd(), "/generatedImages"); + fs.mkdirSync(uploadsPath, { recursive: true }); + fs.mkdirSync(generatedPath, { recursive: true }); + + this.app.use("/media/uploads", express.static(uploadsPath)); + this.app.use("/media/generated", express.static(generatedPath)); - // Call LLM - const response = await generateObject({ - runtime, - context, - modelClass: ModelClass.TEXT_SMALL, - schema: hyperfiOutSchema, - }); + // API Router setup + const apiRouter = createApiRouter(this.agents, this); + this.app.use(apiRouter); - if (!response) { - res.status(500).send("No response from generateMessageResponse"); - return; + logger.success("AgentServer initialization complete"); + } catch (error) { + logger.error("Failed to initialize AgentServer:", error); + throw error; } + } - let hfOut; + public registerAgent(runtime: IAgentRuntime) { try { - hfOut = hyperfiOutSchema.parse(response.object); - } catch { - logger.error("cant serialize response", response.object); - res.status(500).send("Error in LLM response, try again"); - return; - } - - // do this in the background - new Promise((resolve) => { - const contentObj: Content = { - text: hfOut.say, - }; - - if (hfOut.lookAt !== null || hfOut.emote !== null) { - contentObj.text += ". 
Then I "; - if (hfOut.lookAt !== null) { - contentObj.text += `looked at ${hfOut.lookAt}`; - if (hfOut.emote !== null) { - contentObj.text += " and "; - } + if (!runtime) { + throw new Error("Attempted to register null/undefined runtime"); } - if (hfOut.emote !== null) { - contentObj.text += `emoted ${hfOut.emote}`; + if (!runtime.agentId) { + throw new Error("Runtime missing agentId"); + } + if (!runtime.character) { + throw new Error("Runtime missing character configuration"); } - } - - if (hfOut.actions !== null) { - // content can only do one action - contentObj.action = hfOut.actions[0]; - } - - // save response to memory - const responseMessage = { - ...userMessage, - userId: runtime.agentId, - content: contentObj, - }; - - runtime.messageManager.createMemory(responseMessage).then(() => { - const messageId = stringToUuid(Date.now().toString()); - const memory: Memory = { - id: messageId, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - // run evaluators (generally can be done in parallel with processActions) - // can an evaluator modify memory? it could but currently doesn't - runtime.evaluate(memory, state).then(() => { - // only need to call if responseMessage.content.action is set - if (contentObj.action) { - // pass memory (query) to any actions to call - runtime.processActions( - memory, - [responseMessage], - state, - async (_newMessages) => { - // FIXME: this is supposed override what the LLM said/decided - // but the promise doesn't make this possible - //message = newMessages; - return [memory]; - } - ); // 0.674s - } - resolve(true); - }); - }); - }); - res.json({ response: hfOut }); - } - ); - this.app.post( - "/:agentId/image", - async (req: express.Request, res: express.Response) => { - const agentId = req.params.agentId; - const agent = this.agents.get(agentId); - if (!agent) { - res.status(404).send("Agent not found"); - return; - } - const images = await agent.useModel(ModelClass.IMAGE, { ...req.body }); - const imagesRes: { image: string; caption: string }[] = []; - if (images.data && images.data.length > 0) { - for (let i = 0; i < images.data.length; i++) { - const caption = await agent.useModel( - ModelClass.IMAGE_DESCRIPTION, - images.data[i] - ); - imagesRes.push({ - image: images.data[i], - caption: caption.title, - }); - } - } - res.json({ images: imagesRes }); - } - ); + logger.debug(`Registering agent: ${runtime.agentId} (${runtime.character.name})`); + + // Register the agent + this.agents.set(runtime.agentId, runtime); + logger.debug(`Agent ${runtime.agentId} added to agents map`); + + // Register TEE plugin if present + const teePlugin = runtime.plugins.find(p => p.name === "phala-tee-plugin"); + if (teePlugin) { + logger.debug(`Found TEE plugin for agent ${runtime.agentId}`); + for (const provider of teePlugin.providers) { + runtime.registerProvider(provider); + logger.debug(`Registered TEE provider: ${provider.name}`); + } + for (const action of teePlugin.actions) { + runtime.registerAction(action); + logger.debug(`Registered TEE action: ${action.name}`); + } + } - this.app.post( - "/fine-tune", - async (req: express.Request, res: express.Response) => { - try { - const response = await fetch("https://api.bageldb.ai/api/v1/asset", { - method: "POST", - headers: { - "Content-Type": "application/json", - "X-API-KEY": `${process.env.BAGEL_API_KEY}`, - }, - body: JSON.stringify(req.body), - }); + // Register reply action + runtime.registerAction(replyAction); + logger.debug(`Registered reply action for agent 
${runtime.agentId}`); + + // Register routes + logger.debug(`Registering ${runtime.routes.length} custom routes for agent ${runtime.agentId}`); + for (const route of runtime.routes) { + const routePath = route.path; + try { + switch (route.type) { + case "GET": + this.app.get(routePath, (req, res) => route.handler(req, res)); + break; + case "POST": + this.app.post(routePath, (req, res) => route.handler(req, res)); + break; + case "PUT": + this.app.put(routePath, (req, res) => route.handler(req, res)); + break; + case "DELETE": + this.app.delete(routePath, (req, res) => route.handler(req, res)); + break; + default: + logger.error(`Unknown route type: ${route.type} for path ${routePath}`); + continue; + } + logger.debug(`Registered ${route.type} route: ${routePath}`); + } catch (error) { + logger.error(`Failed to register route ${route.type} ${routePath}:`, error); + throw error; + } + } - const data = await response.json(); - res.json(data); + logger.success(`Successfully registered agent ${runtime.agentId} (${runtime.character.name})`); } catch (error) { - res.status(500).json({ - error: - "Please create an account at bakery.bagel.net and get an API key. Then set the BAGEL_API_KEY environment variable.", - details: error.message, - }); + logger.error("Failed to register agent:", error); + throw error; } - } - ); + } - this.app.get( - "/fine-tune/:assetId", - async (req: express.Request, res: express.Response) => { - const assetId = req.params.assetId; - const downloadDir = path.join(process.cwd(), "downloads", assetId); + public unregisterAgent(runtime: IAgentRuntime) { + this.agents.delete(runtime.agentId); + } - logger.log("Download directory:", downloadDir); + public registerMiddleware(middleware: ServerMiddleware) { + this.app.use(middleware); + } + public start(port: number) { try { - logger.log("Creating directory..."); - await fs.promises.mkdir(downloadDir, { recursive: true }); - - logger.log("Fetching file..."); - const fileResponse = await fetch( - `https://api.bageldb.ai/api/v1/asset/${assetId}/download`, - { - headers: { - "X-API-KEY": `${process.env.BAGEL_API_KEY}`, - }, + if (!port || typeof port !== 'number') { + throw new Error(`Invalid port number: ${port}`); } - ); - - if (!fileResponse.ok) { - throw new Error( - `API responded with status ${ - fileResponse.status - }: ${await fileResponse.text()}` - ); - } - - logger.log("Response headers:", fileResponse.headers); - - const fileName = - fileResponse.headers - .get("content-disposition") - ?.split("filename=")[1] - ?.replace(/"/g, /* " */ "") || "default_name.txt"; - - logger.log("Saving as:", fileName); - - const arrayBuffer = await fileResponse.arrayBuffer(); - const buffer = Buffer.from(arrayBuffer); - - const filePath = path.join(downloadDir, fileName); - logger.log("Full file path:", filePath); - - await fs.promises.writeFile(filePath, buffer); + + logger.debug(`Starting server on port ${port}...`); + logger.debug(`Current agents count: ${this.agents.size}`); + logger.debug(`Environment: ${process.env.NODE_ENV}`); + + this.server = this.app.listen(port, () => { + logger.success( + `REST API bound to 0.0.0.0:${port}. If running locally, access it at http://localhost:${port}.` + ); + logger.debug(`Active agents: ${this.agents.size}`); + this.agents.forEach((agent, id) => { + logger.debug(`- Agent ${id}: ${agent.character.name}`); + }); + }); - // Verify file was written - const stats = await fs.promises.stat(filePath); - logger.log("File written successfully. 
Size:", stats.size, "bytes"); + // Enhanced graceful shutdown + const gracefulShutdown = async () => { + logger.log("Received shutdown signal, initiating graceful shutdown..."); + + // Stop all agents first + logger.debug("Stopping all agents..."); + for (const [id, agent] of this.agents.entries()) { + try { + agent.stop(); + logger.debug(`Stopped agent ${id}`); + } catch (error) { + logger.error(`Error stopping agent ${id}:`, error); + } + } + + // Close server + this.server.close(() => { + logger.success("Server closed successfully"); + process.exit(0); + }); + + // Force close after timeout + setTimeout(() => { + logger.error("Could not close connections in time, forcing shutdown"); + process.exit(1); + }, 5000); + }; - res.json({ - success: true, - message: "Single file downloaded successfully", - downloadPath: downloadDir, - fileCount: 1, - fileName: fileName, - fileSize: stats.size, - }); + process.on("SIGTERM", gracefulShutdown); + process.on("SIGINT", gracefulShutdown); + + logger.debug("Shutdown handlers registered"); } catch (error) { - logger.error("Detailed error:", error); - res.status(500).json({ - error: "Failed to download files from BagelDB", - details: error.message, - stack: error.stack, - }); - } - } - ); - - this.app.post("/:agentId/speak", async (req, res) => { - const agentId = req.params.agentId; - const roomId = stringToUuid(req.body.roomId ?? `default-room-${agentId}`); - const userId = stringToUuid(req.body.userId ?? "user"); - const text = req.body.text; - - if (!text) { - res.status(400).send("No text provided"); - return; - } - - let runtime = this.agents.get(agentId); - - // if runtime is null, look for runtime with the same name - if (!runtime) { - runtime = Array.from(this.agents.values()).find( - (a) => a.character.name.toLowerCase() === agentId.toLowerCase() - ); - } - - if (!runtime) { - res.status(404).send("Agent not found"); - return; - } - - try { - // Process message through agent (same as /message endpoint) - await runtime.ensureConnection({ - userId, - roomId, - userName: req.body.userName, - userScreenName: req.body.name, - source: "direct", - type: ChannelType.API, - }); - - const messageId = stringToUuid(Date.now().toString()); - - const content: Content = { - text, - attachments: [], - source: "direct", - inReplyTo: undefined, - }; - - const userMessage = { - content, - userId, - roomId, - agentId: runtime.agentId, - }; - - const memory: Memory = { - id: messageId, - agentId: runtime.agentId, - userId, - roomId, - content, - createdAt: Date.now(), - }; - - await runtime.messageManager.createMemory(memory); - - const state = await runtime.composeState(userMessage, { - agentName: runtime.character.name, - }); - - const context = composeContext({ - state, - template: messageHandlerTemplate, - }); - - const response = await generateMessageResponse({ - runtime: runtime, - context, - modelClass: ModelClass.TEXT_LARGE, - }); - - // save response to memory - const responseMessage = { - ...userMessage, - userId: runtime.agentId, - content: response, - }; - - await runtime.messageManager.createMemory(responseMessage); - - if (!response) { - res.status(500).send("No response from generateMessageResponse"); - return; + logger.error("Failed to start server:", error); + throw error; } - - await runtime.evaluate(memory, state); - - const _result = await runtime.processActions( - memory, - [responseMessage], - state, - async () => { - return [memory]; - } - ); - - // Get the text to convert to speech - const textToSpeak = response.text; - - const 
speechResponse = await runtime.useModel( - ModelClass.TEXT_TO_SPEECH, - textToSpeak - ); - - if (!speechResponse.ok) { - throw new Error(`ElevenLabs API error: ${speechResponse.statusText}`); - } - - const audioBuffer = await speechResponse.arrayBuffer(); - - // Set appropriate headers for audio streaming - res.set({ - "Content-Type": "audio/mpeg", - "Transfer-Encoding": "chunked", - }); - - res.send(Buffer.from(audioBuffer)); - } catch (error) { - logger.error("Error processing message or generating speech:", error); - res.status(500).json({ - error: "Error processing message or generating speech", - details: error.message, - }); - } - }); - - this.app.post("/:agentId/tts", async (req, res) => { - const text = req.body.text; - - if (!text) { - res.status(400).send("No text provided"); - return; - } - - try { - // Convert to speech using ElevenLabs - const elevenLabsApiUrl = `https://api.elevenlabs.io/v1/text-to-speech/${process.env.ELEVENLABS_VOICE_ID}`; - const apiKey = process.env.ELEVENLABS_XI_API_KEY; - - if (!apiKey) { - throw new Error("ELEVENLABS_XI_API_KEY not configured"); - } - - // TODO: Replace the process.env with settings from the character read from the database - - const speechResponse = await fetch(elevenLabsApiUrl, { - method: "POST", - headers: { - "Content-Type": "application/json", - "xi-api-key": apiKey, - }, - body: JSON.stringify({ - text, - model_id: - process.env.ELEVENLABS_MODEL_ID || "eleven_multilingual_v2", - voice_settings: { - stability: Number.parseFloat( - process.env.ELEVENLABS_VOICE_STABILITY || "0.5" - ), - similarity_boost: Number.parseFloat( - process.env.ELEVENLABS_VOICE_SIMILARITY_BOOST || "0.9" - ), - style: Number.parseFloat( - process.env.ELEVENLABS_VOICE_STYLE || "0.66" - ), - use_speaker_boost: - process.env.ELEVENLABS_VOICE_USE_SPEAKER_BOOST === "true", - }, - }), - }); - - if (!speechResponse.ok) { - throw new Error(`ElevenLabs API error: ${speechResponse.statusText}`); - } - - const audioBuffer = await speechResponse.arrayBuffer(); - - res.set({ - "Content-Type": "audio/mpeg", - "Transfer-Encoding": "chunked", - }); - - res.send(Buffer.from(audioBuffer)); - } catch (error) { - logger.error("Error processing message or generating speech:", error); - res.status(500).json({ - error: "Error processing message or generating speech", - details: error.message, - }); - } - }); - } - - // agent/src/index.ts:startAgent calls this - public registerAgent(runtime: IAgentRuntime) { - // register any plugin endpoints? - // but once and only once - this.agents.set(runtime.agentId, runtime); - // TODO: This is a hack to register the tee plugin. Remove this once we have a better way to do it. 
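The TODO above names the real problem with this block: the TEE plugin is looked up by its hard-coded name. Below is a minimal sketch of the generic loop the TODO asks for; only `registerProvider`/`registerAction` and the `plugin.providers`/`plugin.actions` fields come from this series, while the optional-field handling and the standalone helper shape are assumptions about the `Plugin` type:

```ts
import { logger, type IAgentRuntime } from "@elizaos/core";

// Hypothetical generalization of the hard-coded "phala-tee-plugin" lookup:
// walk every plugin on the runtime and register whatever providers and
// actions it exposes, instead of special-casing one plugin by name.
function registerPluginCapabilities(runtime: IAgentRuntime): void {
    for (const plugin of runtime.plugins) {
        // Treating these fields as optional is an assumption.
        for (const provider of plugin.providers ?? []) {
            runtime.registerProvider(provider);
            logger.debug(`Registered provider ${provider.name} from ${plugin.name}`);
        }
        for (const action of plugin.actions ?? []) {
            runtime.registerAction(action);
            logger.debug(`Registered action ${action.name} from ${plugin.name}`);
        }
    }
}
```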
- const teePlugin = runtime.plugins.find( - (p) => p.name === "phala-tee-plugin" - ); - if (teePlugin) { - for (const provider of teePlugin.providers) { - runtime.registerProvider(provider); - } - for (const action of teePlugin.actions) { - runtime.registerAction(action); - } } - runtime.registerAction(replyAction); - // for each route on each plugin, add it to the router - for (const route of runtime.routes) { - // if the path hasn't been added yet, add it - switch (route.type) { - case "GET": - this.app.get(route.path, (req: any, res: any) => - route.handler(req, res) - ); - break; - case "POST": - this.app.post(route.path, (req: any, res: any) => - route.handler(req, res) - ); - break; - case "PUT": - this.app.put(route.path, (req: any, res: any) => - route.handler(req, res) - ); - break; - case "DELETE": - this.app.delete(route.path, (req: any, res: any) => - route.handler(req, res) - ); - break; - default: - logger.error(`Unknown route type: ${route.type}`); - } - } - } - - public unregisterAgent(runtime: IAgentRuntime) { - this.agents.delete(runtime.agentId); - } - - public registerMiddleware(middleware: ServerMiddleware) { - this.app.use(middleware); - } - - public start(port: number) { - this.server = this.app.listen(port, () => { - logger.success( - `REST API bound to 0.0.0.0:${port}. If running locally, access it at http://localhost:${port}.` - ); - }); - - // Handle graceful shutdown - const gracefulShutdown = () => { - logger.log("Received shutdown signal, closing server..."); - this.server.close(() => { - logger.success("Server closed successfully"); - process.exit(0); - }); - // Force close after 5 seconds if server hasn't closed - setTimeout(() => { - logger.error( - "Could not close connections in time, forcefully shutting down" - ); - process.exit(1); - }, 5000); - }; - - // Handle different shutdown signals - process.on("SIGTERM", gracefulShutdown); - process.on("SIGINT", gracefulShutdown); - } - - public async stop() { - if (this.server) { - this.server.close(() => { - logger.success("Server stopped"); - }); + public async stop() { + if (this.server) { + this.server.close(() => { + logger.success("Server stopped"); + }); + } } - } } diff --git a/packages/plugin-local-ai/src/utils/visionManager.ts b/packages/plugin-local-ai/src/utils/visionManager.ts index f639692929e..3f42ffdcce9 100644 --- a/packages/plugin-local-ai/src/utils/visionManager.ts +++ b/packages/plugin-local-ai/src/utils/visionManager.ts @@ -318,7 +318,7 @@ export class VisionManager { private async fetchImage(url: string): Promise<{ buffer: Buffer; mimeType: string }> { try { - logger.info("Fetching image from URL:", { url }); + logger.info(`Fetching image from URL: ${url.slice(0, 100)}...`); // Handle data URLs differently if (url.startsWith('data:')) { From f7097306163bf831e9bf71d906af46a2f00f0d46 Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 17:57:47 +0530 Subject: [PATCH 11/13] type changes --- packages/agent/src/server/index.ts | 11 +++++------ packages/agent/src/single-agent/character.ts | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/packages/agent/src/server/index.ts b/packages/agent/src/server/index.ts index b12cc9dcb9d..1b009f97814 100644 --- a/packages/agent/src/server/index.ts +++ b/packages/agent/src/server/index.ts @@ -1,13 +1,11 @@ import { logger, type Character, - type IAgentRuntime, - type Provider, - type Route + type IAgentRuntime } from "@elizaos/core"; import bodyParser from "body-parser"; import cors from "cors"; -import express, { type Server } 
from "express"; +import express from "express"; import * as fs from "node:fs"; import * as path from "node:path"; @@ -27,8 +25,8 @@ export interface ServerOptions { export class AgentServer { public app: express.Application; private agents: Map; - private server: any; // Change back to any since Server type isn't exported - public startAgent!: (character: Character) => Promise; // Add ! to indicate it will be assigned later + private server: any; + public startAgent!: (character: Character) => Promise; public loadCharacterTryPath!: (characterPath: string) => Promise; public jsonToCharacter!: (character: unknown) => Promise; @@ -208,6 +206,7 @@ export class AgentServer { } } + public async stop() { if (this.server) { this.server.close(() => { diff --git a/packages/agent/src/single-agent/character.ts b/packages/agent/src/single-agent/character.ts index 77cb36f2f14..f77707d5ee3 100644 --- a/packages/agent/src/single-agent/character.ts +++ b/packages/agent/src/single-agent/character.ts @@ -9,8 +9,8 @@ export const defaultCharacter: Character = { username: "eliza", plugins: [ // "@elizaos/plugin-anthropic", - "@elizaos/plugin-openai", - // "@elizaos/plugin-local-ai", + // "@elizaos/plugin-openai", + "@elizaos/plugin-local-ai", // "@elizaos/plugin-elevenlabs", // "@elizaos/plugin-discord", "@elizaos/plugin-node", From 343eabc19075e5c0f83f02f057ea89e232d129ea Mon Sep 17 00:00:00 2001 From: Sayo Date: Tue, 25 Feb 2025 20:06:12 +0530 Subject: [PATCH 12/13] fix direct client and bring back message routes --- packages/agent/src/index.ts | 7 +- packages/agent/src/server/api/agent.ts | 373 ++++++++++++++++++++++++- packages/client/src/lib/api.ts | 32 ++- 3 files changed, 394 insertions(+), 18 deletions(-) diff --git a/packages/agent/src/index.ts b/packages/agent/src/index.ts index 404da9ec305..c1ad070ccd7 100644 --- a/packages/agent/src/index.ts +++ b/packages/agent/src/index.ts @@ -28,9 +28,9 @@ import { import { defaultCharacter } from "./single-agent/character.ts"; import { startScenario } from "./swarm/scenario.ts"; -import swarm from "./swarm/index"; -import * as path from "node:path"; import * as fs from "node:fs"; +import * as path from "node:path"; +import swarm from "./swarm/index"; export const wait = (minTime = 1000, maxTime = 3000) => { const waitTime = @@ -235,8 +235,7 @@ const startAgents = async () => { let serverPort = Number.parseInt(settings.SERVER_PORT || "3000"); const args = parseArguments(); const charactersArg = args.characters || args.character; - const characters: Character[] = []; - + // Add this before creating the AgentServer const dataDir = path.join(process.cwd(), "data"); try { diff --git a/packages/agent/src/server/api/agent.ts b/packages/agent/src/server/api/agent.ts index d228c08464d..8edc98ef2c3 100644 --- a/packages/agent/src/server/api/agent.ts +++ b/packages/agent/src/server/api/agent.ts @@ -1,9 +1,22 @@ import express from 'express'; +import type { Character, IAgentRuntime, Media } from '@elizaos/core'; +import { ChannelType, composeContext, generateMessageResponse, logger, ModelClass, stringToUuid, validateCharacterConfig } from '@elizaos/core'; +import fs from 'node:fs'; import type { AgentServer } from '..'; -import type { Character, IAgentRuntime } from '@elizaos/core'; -import { logger, validateCharacterConfig } from '@elizaos/core'; import { validateUUIDParams } from './api-utils'; +import type { Content, Memory } from '@elizaos/core'; +import path from 'node:path'; +import { messageHandlerTemplate } from '../helper'; +import { upload } from '../loader'; + 
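Since this commit restores the direct-client message routes, a quick client-side smoke test may be useful for reference. The base URL and the `/agents` mount prefix are assumptions here (the prefix matches the `api.ts` changes later in this patch):

```ts
// Minimal sketch of exercising the restored message route, assuming the
// server is listening on localhost:3000 and agentRouter is mounted under
// the /agents prefix used by the client.
async function sendDirectMessage(agentId: string, text: string) {
    const res = await fetch(`http://localhost:3000/agents/${agentId}/message`, {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
            text,
            userId: "user", // maps to stringToUuid(req.body.userId ?? "user")
            roomId: `default-room-${agentId}`, // the handler's default room
        }),
    });
    if (!res.ok) {
        throw new Error(`Message request failed: ${res.status} ${await res.text()}`);
    }
    // The reply handler responds with an array of Content objects.
    return (await res.json()) as Array<{ text: string }>;
}
```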
+interface CustomRequest extends express.Request { + file?: Express.Multer.File; + params: { + agentId: string; + }; +} + export function agentRouter( agents: Map, directClient: AgentServer @@ -63,6 +76,168 @@ export function agentRouter( } }); + router.post('/:agentId/message', async (req: CustomRequest, res) => { + logger.info("[MESSAGE ENDPOINT] **ROUTE HIT** - Entering /message endpoint"); + + const { agentId } = validateUUIDParams(req?.params, res) ?? { + agentId: null, + }; + if (!agentId) return; + + // Add logging to debug the request body + logger.info(`[MESSAGE ENDPOINT] Raw body: ${JSON.stringify(req.body)}`); + + const text = req.body?.text?.trim(); + logger.info(`[MESSAGE ENDPOINT] Parsed text: ${text}`); + + // Move the text validation check here, before other processing + if (!text) { + res.status(400).json({ error: "Text message is required" }); + return; + } + + const roomId = stringToUuid(req.body.roomId ?? `default-room-${agentId}`); + const userId = stringToUuid(req.body.userId ?? "user"); + + let runtime = agents.get(agentId); + + // if runtime is null, look for runtime with the same name + if (!runtime) { + runtime = Array.from(agents.values()).find( + (a) => + a.character.name.toLowerCase() === + agentId.toLowerCase() + ); + } + + if (!runtime) { + res.status(404).json({ error: 'Agent not found' }); + return; + } + + logger.info(`[MESSAGE ENDPOINT] Runtime found: ${runtime?.character?.name}`); + + try { + await runtime.ensureConnection({ + userId, + roomId, + userName: req.body.userName, + userScreenName: req.body.name, + source: "direct", + type: ChannelType.API, + }); + + logger.info(`[MESSAGE ENDPOINT] req.body: ${JSON.stringify(req.body)}`); + + const messageId = stringToUuid(Date.now().toString()); + + const attachments: Media[] = []; + if (req.file) { + const filePath = path.join( + process.cwd(), + "data", + "uploads", + req.file.filename + ); + attachments.push({ + id: Date.now().toString(), + url: filePath, + title: req.file.originalname, + source: "direct", + description: `Uploaded file: ${req.file.originalname}`, + text: "", + contentType: req.file.mimetype, + }); + } + + const content: Content = { + text, + attachments, + source: "direct", + inReplyTo: undefined, + }; + + const userMessage = { + content, + userId, + roomId, + agentId: runtime.agentId, + }; + + const memory: Memory = { + id: stringToUuid(`${messageId}-${userId}`), + ...userMessage, + agentId: runtime.agentId, + userId, + roomId, + content, + createdAt: Date.now(), + }; + + await runtime.messageManager.addEmbeddingToMemory(memory); + await runtime.messageManager.createMemory(memory); + + let state = await runtime.composeState(userMessage, { + agentName: runtime.character.name, + }); + + const context = composeContext({ + state, + template: messageHandlerTemplate, + }); + + logger.info("[MESSAGE ENDPOINT] Before generateMessageResponse"); + + const response = await generateMessageResponse({ + runtime: runtime, + context, + modelClass: ModelClass.TEXT_LARGE, + }); + + logger.info(`[MESSAGE ENDPOINT] After generateMessageResponse, response: ${response}`); + + if (!response) { + res.status(500).json({ + error: "No response from generateMessageResponse" + }); + return; + } + + // save response to memory + const responseMessage: Memory = { + id: stringToUuid(`${messageId}-${runtime.agentId}`), + ...userMessage, + userId: runtime.agentId, + content: response, + createdAt: Date.now(), + }; + + await runtime.messageManager.createMemory(responseMessage); + + state = await 
runtime.updateRecentMessageState(state); + + const replyHandler = async (message: Content) => { + res.json([message]); + return [memory]; + } + + await runtime.processActions( + memory, + [responseMessage], + state, + replyHandler + ); + + await runtime.evaluate(memory, state); + } catch (error) { + logger.error("Error processing message:", error); + res.status(500).json({ + error: "Error processing message", + details: error.message + }); + } + }); + router.post('/:agentId/set', async (req, res) => { const { agentId } = validateUUIDParams(req.params, res) ?? { agentId: null, @@ -256,5 +431,197 @@ export function agentRouter( } }); + router.post('/:agentId/whisper', upload.single('file'), async (req: CustomRequest, res: express.Response) => { + const audioFile = req.file; + const agentId = req.params.agentId; + + if (!audioFile) { + res.status(400).send("No audio file provided"); + return; + } + + let runtime = agents.get(agentId); + + if (!runtime) { + runtime = Array.from(agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } + + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } + + const audioBuffer = fs.readFileSync(audioFile.path); + const transcription = await runtime.useModel(ModelClass.TRANSCRIPTION, audioBuffer); + + res.json({text: transcription}); + }); + + router.post('/:agentId/speak', async (req, res) => { + const { agentId } = validateUUIDParams(req.params, res) ?? { agentId: null }; + if (!agentId) return; + + const { text, roomId: rawRoomId, userId: rawUserId } = req.body; + const roomId = stringToUuid(rawRoomId ?? `default-room-${agentId}`); + const userId = stringToUuid(rawUserId ?? "user"); + + if (!text) { + res.status(400).send("No text provided"); + return; + } + + let runtime = agents.get(agentId); + if (!runtime) { + runtime = Array.from(agents.values()).find( + (a) => a.character.name.toLowerCase() === agentId.toLowerCase() + ); + } + + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } + + try { + // Process message through agent + await runtime.ensureConnection({ + userId, + roomId, + userName: req.body.userName, + userScreenName: req.body.name, + source: "direct", + type: ChannelType.API, + }); + + const messageId = stringToUuid(Date.now().toString()); + + const content: Content = { + text, + attachments: [], + source: "direct", + inReplyTo: undefined, + }; + + const userMessage = { + content, + userId, + roomId, + agentId: runtime.agentId, + }; + + const memory: Memory = { + id: messageId, + agentId: runtime.agentId, + userId, + roomId, + content, + createdAt: Date.now(), + }; + + await runtime.messageManager.createMemory(memory); + + const state = await runtime.composeState(userMessage, { + agentName: runtime.character.name, + }); + + + const context = composeContext({ + state, + template: messageHandlerTemplate, + }); + + const response = await runtime.useModel(ModelClass.TEXT_LARGE, { + messages: [{ + role: 'system', + content: messageHandlerTemplate + }, { + role: 'user', + content: context + }] + }); + + // save response to memory + const responseMessage = { + ...userMessage, + userId: runtime.agentId, + content: response, + }; + + await runtime.messageManager.createMemory(responseMessage); + + if (!response) { + res.status(500).send( + "No response from generateMessageResponse" + ); + return; + } + + await runtime.evaluate(memory, state); + + const _result = await runtime.processActions( + memory, + [responseMessage], + state, + async () => { + return [memory]; + } + 
); + + const speechResponse = await runtime.useModel(ModelClass.TEXT_TO_SPEECH, response.text); + const audioBuffer = await speechResponse.arrayBuffer(); + + res.set({ + "Content-Type": "audio/mpeg", + "Transfer-Encoding": "chunked", + }); + + res.send(Buffer.from(audioBuffer)); + } catch (error) { + logger.error("Error processing message or generating speech:", error); + res.status(500).json({ + error: "Error processing message or generating speech", + details: error.message, + }); + } + }); + + router.post('/:agentId/tts', async (req, res) => { + const { agentId } = validateUUIDParams(req.params, res) ?? { agentId: null }; + if (!agentId) return; + + const { text } = req.body; + if (!text) { + res.status(400).send("No text provided"); + return; + } + + const runtime = agents.get(agentId); + if (!runtime) { + res.status(404).send("Agent not found"); + return; + } + + try { + const speechResponse = await runtime.useModel(ModelClass.TEXT_TO_SPEECH, text); + const audioBuffer = await speechResponse.arrayBuffer(); + + res.set({ + "Content-Type": "audio/mpeg", + "Transfer-Encoding": "chunked", + }); + + res.send(Buffer.from(audioBuffer)); + } catch (error) { + logger.error("Error generating speech:", error); + res.status(500).json({ + error: "Error generating speech", + details: error.message, + }); + } + }); + return router; -} \ No newline at end of file +} + diff --git a/packages/client/src/lib/api.ts b/packages/client/src/lib/api.ts index f780866875e..577e8920957 100644 --- a/packages/client/src/lib/api.ts +++ b/packages/client/src/lib/api.ts @@ -69,25 +69,35 @@ export const apiClient = { message: string, selectedFile?: File | null ) => { - const formData = new FormData(); - formData.append("text", message); - formData.append("user", "user"); - if (selectedFile) { + // Use FormData only when there's a file + const formData = new FormData(); + formData.append("text", message); + formData.append("user", "user"); formData.append("file", selectedFile); + + return fetcher({ + url: `/agents/${agentId}/message`, + method: "POST", + body: formData, + }); } - return fetcher({ - url: `/${agentId}/message`, - method: "POST", - body: formData, - }); + // Use JSON when there's no file + return fetcher({ + url: `/agents/${agentId}/message`, + method: "POST", + body: { + text: message, + user: "user" + }, + }); }, getAgents: () => fetcher({ url: "/agents" }), getAgent: (agentId: string): Promise<{ id: UUID; character: Character }> => fetcher({ url: `/agents/${agentId}` }), tts: (agentId: string, text: string) => fetcher({ - url: `/${agentId}/tts`, + url: `/agents/${agentId}/tts`, method: "POST", body: { text, @@ -102,7 +112,7 @@ export const apiClient = { const formData = new FormData(); formData.append("file", audioBlob, "recording.wav"); return fetcher({ - url: `/${agentId}/whisper`, + url: `/agents/${agentId}/whisper`, method: "POST", body: formData, }); From cd0a33ee6296fddec4fa0d974b3f3908a891c9a0 Mon Sep 17 00:00:00 2001 From: Ting Chien Meng Date: Tue, 25 Feb 2025 23:30:51 +0800 Subject: [PATCH 13/13] set up dimension before starting client --- packages/core/src/runtime.ts | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/packages/core/src/runtime.ts b/packages/core/src/runtime.ts index 256ba09e603..b1675546c7b 100644 --- a/packages/core/src/runtime.ts +++ b/packages/core/src/runtime.ts @@ -548,6 +548,15 @@ export class AgentRuntime implements IAgentRuntime { ); await this.processCharacterKnowledge(stringKnowledge); } + + // Check if TEXT_EMBEDDING model is 
registered + const embeddingModel = this.getModel(ModelClass.TEXT_EMBEDDING); + if (!embeddingModel) { + logger.warn(`[AgentRuntime][${this.character.name}] No TEXT_EMBEDDING model registered. Skipping embedding dimension setup.`); + } else { + // Only run ensureEmbeddingDimension if we have an embedding model + await this.ensureEmbeddingDimension(); + } // Initialize services if (this.services) { @@ -563,15 +572,6 @@ export class AgentRuntime implements IAgentRuntime { this.registerClient(clientInterface.name, startedClient); }) ); - - // Check if TEXT_EMBEDDING model is registered - const embeddingModel = this.getModel(ModelClass.TEXT_EMBEDDING); - if (!embeddingModel) { - logger.warn(`[AgentRuntime][${this.character.name}] No TEXT_EMBEDDING model registered. Skipping embedding dimension setup.`); - } else { - // Only run ensureEmbeddingDimension if we have an embedding model - await this.ensureEmbeddingDimension(); - } } async ensureAgentExists() {
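The reasoning behind this reorder: clients started during initialization can create memories immediately, and embedding those memories requires the vector dimension to be configured first. A condensed sketch of the resulting order follows, using a minimal structural runtime type so it stays self-contained; the helper itself is illustrative, not code from the patch:

```ts
import { logger, ModelClass } from "@elizaos/core";

// Minimal structural view of the runtime for this sketch; the real
// methods live on AgentRuntime in runtime.ts.
type EmbeddingCapableRuntime = {
    getModel(modelClass: ModelClass): unknown;
    ensureEmbeddingDimension(): Promise<void>;
};

async function initializeInOrder(
    runtime: EmbeddingCapableRuntime,
    startClients: () => Promise<void>
): Promise<void> {
    // Resolve the embedding dimension first...
    const embeddingModel = runtime.getModel(ModelClass.TEXT_EMBEDDING);
    if (!embeddingModel) {
        logger.warn("No TEXT_EMBEDDING model registered. Skipping embedding dimension setup.");
    } else {
        await runtime.ensureEmbeddingDimension();
    }
    // ...and only then start clients, which may create (and embed)
    // memories as soon as they connect.
    await startClients();
}
```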
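Taken together, the series leaves the server configurable with middleware in two ways: up front through `ServerOptions`, or after construction via `registerMiddleware`. A closing usage sketch; the import path and the header value are illustrative assumptions:

```ts
import { logger } from "@elizaos/core";
// Import path is an assumption; AgentServer and ServerMiddleware are
// defined in packages/agent/src/server/index.ts in this series.
import { AgentServer, type ServerMiddleware } from "./server/index.ts";

// A request logger matching the ServerMiddleware signature.
const requestLogger: ServerMiddleware = (req, _res, next) => {
    logger.debug(`${req.method} ${req.url}`);
    next();
};

// Middlewares can be supplied up front through ServerOptions...
const server = new AgentServer({ middlewares: [requestLogger] });

// ...or attached after construction.
server.registerMiddleware((_req, res, next) => {
    res.setHeader("X-Agent-Server", "elizaos"); // illustrative header only
    next();
});

// SERVER_PORT mirrors the settings key used in agent/src/index.ts.
server.start(Number(process.env.SERVER_PORT ?? "3000"));
```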