diff --git a/packages/core/src/generation.ts b/packages/core/src/generation.ts
index b70c6dd69f4..19d8035ea26 100644
--- a/packages/core/src/generation.ts
+++ b/packages/core/src/generation.ts
@@ -783,11 +783,6 @@ export const generateImage = async (
     data?: string[];
     error?: any;
 }> => {
-    const { prompt, width, height } = data;
-    let { count } = data;
-    if (!count) {
-        count = 1;
-    }
     const model = getModel(runtime.imageModelProvider, ModelClass.IMAGE);
     const modelSettings = models[runtime.imageModelProvider].imageSettings;
 
@@ -843,16 +838,18 @@ export const generateImage = async (
         const imageURL = await response.json();
         return { success: true, data: [imageURL] };
     } else if (
+        // TODO: Fix LLAMACLOUD -> Together?
         runtime.imageModelProvider === ModelProviderName.LLAMACLOUD
     ) {
         const together = new Together({ apiKey: apiKey as string });
+        // Fix: 4 steps is for schnell; 28 is for dev.
         const response = await together.images.create({
             model: "black-forest-labs/FLUX.1-schnell",
-            prompt,
-            width,
-            height,
+            prompt: data.prompt,
+            width: data.width,
+            height: data.height,
             steps: modelSettings?.steps ?? 4,
-            n: count,
+            n: data.count,
         });
         const urls: string[] = [];
         for (let i = 0; i < response.data.length; i++) {
@@ -879,11 +876,11 @@ export const generateImage = async (
 
         // Prepare the input parameters according to their schema
         const input = {
-            prompt: prompt,
+            prompt: data.prompt,
             image_size: "square" as const,
             num_inference_steps: modelSettings?.steps ?? 50,
-            guidance_scale: 3.5,
-            num_images: count,
+            guidance_scale: data.guidanceScale || 3.5,
+            num_images: data.count,
             enable_safety_checker: true,
             output_format: "png" as const,
             seed: data.seed ?? 6252023,
@@ -933,9 +930,9 @@ export const generateImage = async (
         const openai = new OpenAI({ apiKey: apiKey as string });
         const response = await openai.images.generate({
             model,
-            prompt,
+            prompt: data.prompt,
             size: targetSize as "1024x1024" | "1792x1024" | "1024x1792",
-            n: count,
+            n: data.count,
             response_format: "b64_json",
         });
         const base64s = response.data.map(
diff --git a/packages/plugin-image-generation/src/enviroment.ts b/packages/plugin-image-generation/src/environment.ts
similarity index 100%
rename from packages/plugin-image-generation/src/enviroment.ts
rename to packages/plugin-image-generation/src/environment.ts
diff --git a/packages/plugin-image-generation/src/index.ts b/packages/plugin-image-generation/src/index.ts
index 8a8ec4102a4..897572f7f9a 100644
--- a/packages/plugin-image-generation/src/index.ts
+++ b/packages/plugin-image-generation/src/index.ts
@@ -11,7 +11,7 @@ import { generateImage } from "@ai16z/eliza";
 import fs from "fs";
 import path from "path";
 
-import { validateImageGenConfig } from "./enviroment";
+import { validateImageGenConfig } from "./environment";
 
 export function saveBase64Image(base64Data: string, filename: string): string {
     // Create generatedImages directory if it doesn't exist
@@ -97,7 +97,17 @@ const imageGeneration: Action = {
         runtime: IAgentRuntime,
         message: Memory,
         state: State,
-        options: any,
+        options: {
+            width?: number;
+            height?: number;
+            count?: number;
+            negativePrompt?: string;
+            numIterations?: number;
+            guidanceScale?: number;
+            seed?: number;
+            modelId?: string;
+            jobId?: string;
+        },
         callback: HandlerCallback
     ) => {
         elizaLogger.log("Composing state for message:", message);
@@ -116,9 +126,15 @@ const imageGeneration: Action = {
         const images = await generateImage(
             {
                 prompt: imagePrompt,
-                width: 1024,
-                height: 1024,
-                count: 1,
+                ...(options.width !== undefined ? { width: options.width || 1024 } : {}),
+                ...(options.height !== undefined ? { height: options.height || 1024 } : {}),
+                ...(options.count !== undefined ? { count: options.count || 1 } : {}),
+                ...(options.negativePrompt !== undefined ? { negativePrompt: options.negativePrompt } : {}),
+                ...(options.numIterations !== undefined ? { numIterations: options.numIterations } : {}),
+                ...(options.guidanceScale !== undefined ? { guidanceScale: options.guidanceScale } : {}),
+                ...(options.seed !== undefined ? { seed: options.seed } : {}),
+                ...(options.modelId !== undefined ? { modelId: options.modelId } : {}),
+                ...(options.jobId !== undefined ? { jobId: options.jobId } : {})
             },
             runtime
         );
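For reference, a minimal sketch of how a caller might pass the new optional fields through to generateImage after this change. The field names (width, count, guidanceScale, seed) come from the diff above; the runExample wrapper, the example values, and the import shape are illustrative assumptions, not part of this PR. Note that because the `count = 1` default was removed from generation.ts, callers who care about image count should pass it explicitly so `n: data.count` is defined.

// Illustrative usage only -- the real handler wiring lives in plugin-image-generation/src/index.ts.
import { generateImage, IAgentRuntime } from "@ai16z/eliza";

async function runExample(runtime: IAgentRuntime, imagePrompt: string) {
    // Hypothetical caller-supplied options; in the plugin these arrive via the typed `options` parameter.
    const options = { width: 512, count: 2, guidanceScale: 7, seed: 42 };

    const result = await generateImage(
        {
            prompt: imagePrompt,
            // Forward only the fields the caller actually set, mirroring the conditional spreads above.
            ...(options.width !== undefined ? { width: options.width } : {}),
            ...(options.count !== undefined ? { count: options.count } : {}),
            ...(options.guidanceScale !== undefined ? { guidanceScale: options.guidanceScale } : {}),
            ...(options.seed !== undefined ? { seed: options.seed } : {}),
        },
        runtime
    );

    if (result.success && result.data) {
        console.log(`Generated ${result.data.length} image(s)`);
    }
}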