Skip to content

Commit

Permalink
feat: improve object generative by maigc prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
jeasonstudio committed Jul 2, 2024
1 parent ec9e334 commit 2cca55e
Showing 1 changed file with 60 additions and 44 deletions.
104 changes: 60 additions & 44 deletions src/language-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import {
LanguageModelV1FunctionToolCall,
LanguageModelV1ImagePart,
LanguageModelV1LogProbs,
LanguageModelV1Message,
LanguageModelV1Prompt,
LanguageModelV1StreamPart,
LanguageModelV1TextPart,
Expand Down Expand Up @@ -35,6 +34,7 @@ export interface ChromeAIChatSettings extends Record<string, unknown> {

/**
* Optional. A list of unique safety settings for blocking unsafe content.
* @note this is not working yet
*/
safetySettings?: Array<{
category:
Expand All @@ -60,13 +60,13 @@ function getStringContent(
| LanguageModelV1ToolResultPart[]
): string {
if (typeof content === 'string') {
return content;
return content.trim();
} else if (Array.isArray(content) && content.length > 0) {
const [first] = content;
if (first.type !== 'text') {
throw new UnsupportedFunctionalityError({ functionality: 'toolCall' });
}
return first.text;
return first.text.trim();
} else {
return '';
}
Expand Down Expand Up @@ -122,55 +122,53 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
return this.session;
};

private magicPrompts: Record<
LanguageModelV1CallOptions['mode']['type'],
string | null
> = {
regular: null,
'object-grammar': null,
'object-tool': null,
'object-json':
'Always response pure json string format that matches the JSON schema above, not markdown or other format!!',
};
private roleMap: Record<LanguageModelV1Message['role'], string> = {
system: 'system',
user: 'user',
tool: 'user',
assistant: 'model',
};

private formatMessages = (options: LanguageModelV1CallOptions): string => {
let prompt: LanguageModelV1Prompt = options.prompt;
debug('before format prompt:', prompt);

let result = '';

// When the user supplied a prompt input, we don't transform it:
if (
// When the user supplied a prompt input, we don't transform it
options.inputFormat === 'prompt' &&
prompt.length === 1 &&
prompt[0].role === 'user' &&
prompt[0].content.length === 1 &&
prompt[0].content[0].type === 'text'
) {
debug('formated message:', prompt[0].content[0].text);
return prompt[0].content[0].text;
}
result += prompt[0].content[0].text;
} else {
// Use magic prompt for object-json mode
if (options.mode.type === 'object-json') {
prompt.unshift({
role: 'system',
content: `Throughout our conversation, always start your responses with "{" and end with "}", ensuring the output is a concise JSON object and strictly avoid including any comments, notes, explanations, or examples in your output.\nFor instance, if the JSON schema is {"type":"object","properties":{"someKey":{"type":"string"}},"required":["someKey"],"additionalProperties":false,"$schema":"http://json-schema.org/draft-07/schema#"}, your response should immediately begin with "{" and strictly end with "}", following the format: {"someKey": "someValue"}.\nAdhere to this format for all queries moving forward.`,
});
}

for (let index = 0; index < prompt.length; index += 1) {
const { role, content } = prompt[index];
const contentString = getStringContent(content);

// FIXME: something tricky here
if (this.magicPrompts[options.mode.type]) {
prompt = [
{ role: 'system', content: this.magicPrompts[options.mode.type]! },
...prompt,
];
switch (role) {
case 'system':
result += `${contentString}\n`;
break;
case 'assistant':
case 'tool':
result += `model\n${contentString}\n`;
break;
case 'user':
default:
result += `user\n${contentString}\n`;
break;
}
}
result += `model\n`;
}

const messages = prompt
.map(
({ role, content }) =>
`${this.roleMap[role]}:\n${getStringContent(content)}`
)
.join('\n\n');
debug('format prompt:', prompt);
debug('formated message:', messages);
return messages + `\n`;
debug('formated message:', result);
return result;
};

public doGenerate = async (
Expand All @@ -185,10 +183,17 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
warnings?: LanguageModelV1CallWarning[];
logprobs?: LanguageModelV1LogProbs;
}> => {
const session = await this.getSession({ temperature: options.temperature });
debug('generate options:', options);

if (['regular', 'object-json'].indexOf(options.mode.type) < 0) {
throw new UnsupportedFunctionalityError({
functionality: `${options.mode.type} mode`,
});
}

const session = await this.getSession();
const message = this.formatMessages(options);
const text = await session.prompt(message);
debug('generate options:', options);
debug('generate result:', text);
return {
text,
Expand All @@ -206,24 +211,35 @@ export class ChromeAIChatLanguageModel implements LanguageModelV1 {
rawResponse?: { headers?: Record<string, string> };
warnings?: LanguageModelV1CallWarning[];
}> => {
const session = await this.getSession({ temperature: options.temperature });
debug('stream options:', options);

if (['regular'].indexOf(options.mode.type) < 0) {
throw new UnsupportedFunctionalityError({
functionality: `${options.mode.type} mode`,
});
}

const session = await this.getSession();
const message = this.formatMessages(options);
const promptStream = session.promptStreaming(message);
debug('stream options:', options);

let tempResult = '';
const transformStream = new TransformStream<
string,
LanguageModelV1StreamPart
>({
transform(textDelta, controller) {
controller.enqueue({ type: 'text-delta', textDelta });
tempResult = textDelta;
},
flush(controller) {
controller.enqueue({
type: 'finish',
finishReason: 'stop',
usage: { completionTokens: 0, promptTokens: 0 },
});
controller.terminate();
debug('stream result:', tempResult);
tempResult = '';
},
});
const stream = promptStream.pipeThrough(transformStream);
Expand Down

0 comments on commit 2cca55e

Please sign in to comment.