diff --git a/src/language-model.ts b/src/language-model.ts index 9b5477c..063e66f 100644 --- a/src/language-model.ts +++ b/src/language-model.ts @@ -6,7 +6,6 @@ import { LanguageModelV1FunctionToolCall, LanguageModelV1ImagePart, LanguageModelV1LogProbs, - LanguageModelV1Message, LanguageModelV1Prompt, LanguageModelV1StreamPart, LanguageModelV1TextPart, @@ -35,6 +34,7 @@ export interface ChromeAIChatSettings extends Record { /** * Optional. A list of unique safety settings for blocking unsafe content. + * @note this is not working yet */ safetySettings?: Array<{ category: @@ -60,13 +60,13 @@ function getStringContent( | LanguageModelV1ToolResultPart[] ): string { if (typeof content === 'string') { - return content; + return content.trim(); } else if (Array.isArray(content) && content.length > 0) { const [first] = content; if (first.type !== 'text') { throw new UnsupportedFunctionalityError({ functionality: 'toolCall' }); } - return first.text; + return first.text.trim(); } else { return ''; } @@ -122,55 +122,53 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { return this.session; }; - private magicPrompts: Record< - LanguageModelV1CallOptions['mode']['type'], - string | null - > = { - regular: null, - 'object-grammar': null, - 'object-tool': null, - 'object-json': - 'Always response pure json string format that matches the JSON schema above, not markdown or other format!!', - }; - private roleMap: Record = { - system: 'system', - user: 'user', - tool: 'user', - assistant: 'model', - }; - private formatMessages = (options: LanguageModelV1CallOptions): string => { let prompt: LanguageModelV1Prompt = options.prompt; + debug('before format prompt:', prompt); + + let result = ''; - // When the user supplied a prompt input, we don't transform it: if ( + // When the user supplied a prompt input, we don't transform it options.inputFormat === 'prompt' && prompt.length === 1 && prompt[0].role === 'user' && prompt[0].content.length === 1 && prompt[0].content[0].type === 'text' ) { - debug('formated message:', prompt[0].content[0].text); - return prompt[0].content[0].text; - } + result += prompt[0].content[0].text; + } else { + // Use magic prompt for object-json mode + if (options.mode.type === 'object-json') { + prompt.unshift({ + role: 'system', + content: `Throughout our conversation, always start your responses with "{" and end with "}", ensuring the output is a concise JSON object and strictly avoid including any comments, notes, explanations, or examples in your output.\nFor instance, if the JSON schema is {"type":"object","properties":{"someKey":{"type":"string"}},"required":["someKey"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}, your response should immediately begin with "{" and strictly end with "}", following the format: {"someKey": "someValue"}.\nAdhere to this format for all queries moving forward.`, + }); + } + + for (let index = 0; index < prompt.length; index += 1) { + const { role, content } = prompt[index]; + const contentString = getStringContent(content); - // FIXME: something tricky here - if (this.magicPrompts[options.mode.type]) { - prompt = [ - { role: 'system', content: this.magicPrompts[options.mode.type]! }, - ...prompt, - ]; + switch (role) { + case 'system': + result += `${contentString}\n`; + break; + case 'assistant': + case 'tool': + result += `model\n${contentString}\n`; + break; + case 'user': + default: + result += `user\n${contentString}\n`; + break; + } + } + result += `model\n`; } - const messages = prompt - .map( - ({ role, content }) => - `${this.roleMap[role]}:\n${getStringContent(content)}` - ) - .join('\n\n'); - debug('format prompt:', prompt); - debug('formated message:', messages); - return messages + `\n`; + debug('formated message:', result); + return result; }; public doGenerate = async ( @@ -185,10 +183,17 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { warnings?: LanguageModelV1CallWarning[]; logprobs?: LanguageModelV1LogProbs; }> => { - const session = await this.getSession({ temperature: options.temperature }); + debug('generate options:', options); + + if (['regular', 'object-json'].indexOf(options.mode.type) < 0) { + throw new UnsupportedFunctionalityError({ + functionality: `${options.mode.type} mode`, + }); + } + + const session = await this.getSession(); const message = this.formatMessages(options); const text = await session.prompt(message); - debug('generate options:', options); debug('generate result:', text); return { text, @@ -206,16 +211,26 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { rawResponse?: { headers?: Record }; warnings?: LanguageModelV1CallWarning[]; }> => { - const session = await this.getSession({ temperature: options.temperature }); + debug('stream options:', options); + + if (['regular'].indexOf(options.mode.type) < 0) { + throw new UnsupportedFunctionalityError({ + functionality: `${options.mode.type} mode`, + }); + } + + const session = await this.getSession(); const message = this.formatMessages(options); const promptStream = session.promptStreaming(message); - debug('stream options:', options); + + let tempResult = ''; const transformStream = new TransformStream< string, LanguageModelV1StreamPart >({ transform(textDelta, controller) { controller.enqueue({ type: 'text-delta', textDelta }); + tempResult = textDelta; }, flush(controller) { controller.enqueue({ @@ -223,7 +238,8 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 { finishReason: 'stop', usage: { completionTokens: 0, promptTokens: 0 }, }); - controller.terminate(); + debug('stream result:', tempResult); + tempResult = ''; }, }); const stream = promptStream.pipeThrough(transformStream);