Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactors Message into its own file #1120

Merged
merged 4 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 13 additions & 94 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import { generateHelper, GenerateUtilParamSchema } from './generateAction.js';
import { Message } from './message.js';
import {
GenerateRequest,
GenerateResponseChunkData,
Expand All @@ -41,88 +42,10 @@ import {
Part,
ToolDefinition,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import { ExecutablePrompt } from './prompt.js';
import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';

/**
 * Message represents a single role's contribution to a generation. Each message
 * can contain multiple parts (for example text and an image), and each generation
 * can contain multiple messages.
 */
export class Message<T = unknown> implements MessageData {
  role: MessageData['role'];
  content: Part[];

  /**
   * Wraps plain message data. The `content` array is stored by reference,
   * not copied.
   * @param message The message data to wrap.
   */
  constructor(message: MessageData) {
    this.role = message.role;
    this.content = message.content;
  }

  /**
   * If a message contains a `data` part, it is returned. Otherwise, the
   * `output` getter extracts the first valid JSON object or array from the
   * text contained in the message and returns it.
   *
   * @returns The structured output contained in the message.
   */
  get output(): T {
    // `data` is either a truthy value or null, so `??` selects the same
    // branch that `||` would.
    return this.data ?? extractJson<T>(this.text);
  }

  /**
   * Returns all tool response parts found in this message.
   * @returns Array of all parts that carry a `toolResponse`.
   */
  toolResponseParts(): ToolResponsePart[] {
    const res = this.content.filter((part) => !!part.toolResponse);
    return res as ToolResponsePart[];
  }

  /**
   * Concatenates all `text` parts present in the message with no delimiter.
   * @returns A string of all concatenated text parts.
   */
  get text(): string {
    return this.content.map((part) => part.text || '').join('');
  }

  /**
   * Returns the first media part detected in the message. Useful for extracting
   * (for example) an image from a generation expected to create one.
   * @returns The first detected `media` part in the message, or `null`.
   */
  get media(): { url: string; contentType?: string } | null {
    return this.content.find((part) => part.media)?.media || null;
  }

  /**
   * Returns the first detected `data` part of a message.
   * @returns The first `data` part detected in the message, or `null` when no
   * part carries data.
   */
  get data(): T | null {
    const found = this.content.find((part) => part.data);
    // Normalize "not found" to null so the runtime value matches the declared
    // return type (and the behavior of the `media` getter).
    return found ? (found.data as T) : null;
  }

  /**
   * Returns all tool requests found in this message.
   * @returns Array of all tool requests found in this message.
   */
  get toolRequests(): ToolRequestPart[] {
    return this.content.filter(
      (part) => !!part.toolRequest
    ) as ToolRequestPart[];
  }

  /**
   * Converts the Message to a plain JS object.
   * @returns Plain JS object representing the data contained in the message.
   * The `content` array is a shallow copy.
   */
  toJSON(): MessageData {
    return {
      role: this.role,
      content: [...this.content],
    };
  }
}

/**
* GenerateResponse is the result from a `generate()` call and contains one or
* more generated candidate messages.
Expand Down Expand Up @@ -361,29 +284,25 @@ export class GenerateResponseChunk<T = unknown>
}
}

/**
 * Coerces loosely specified content into a list of parts.
 *
 * @param input A bare string, a single part, or an array of parts.
 * @returns The input array itself when an array is given; otherwise a
 * one-element array holding either a text part built from the string or the
 * single part that was passed in.
 */
export function normalizePart(input: string | Part | Part[]): Part[] {
  if (Array.isArray(input)) {
    // Arrays are passed through untouched (same reference).
    return input;
  }
  return typeof input === 'string' ? [{ text: input }] : [input];
}

export async function toGenerateRequest(
registry: Registry,
options: GenerateOptions
): Promise<GenerateRequest> {
const messages: MessageData[] = [];
if (options.system) {
messages.push({ role: 'system', content: normalizePart(options.system) });
messages.push({
role: 'system',
content: Message.parseContent(options.system),
});
}
if (options.messages) {
messages.push(...options.messages);
messages.push(...options.messages.map((m) => Message.parseData(m)));
}
if (options.prompt) {
messages.push({ role: 'user', content: normalizePart(options.prompt) });
messages.push({
role: 'user',
content: Message.parseContent(options.prompt),
});
}
if (messages.length === 0) {
throw new Error('at least one message is required in generate request');
Expand Down Expand Up @@ -427,7 +346,7 @@ export interface GenerateOptions<
/** Retrieved documents to be used as context for this generation. */
docs?: DocumentData[];
/** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */
messages?: MessageData[];
messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[];
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
tools?: ToolArgument[];
/** Configuration for the generation request. */
Expand Down Expand Up @@ -569,7 +488,7 @@ export async function generate<
if (resolvedOptions.system) {
messages.push({
role: 'system',
content: normalizePart(resolvedOptions.system),
content: Message.parseContent(resolvedOptions.system),
});
}
if (resolvedOptions.messages) {
Expand All @@ -578,7 +497,7 @@ export async function generate<
if (resolvedOptions.prompt) {
messages.push({
role: 'user',
content: normalizePart(resolvedOptions.prompt),
content: Message.parseContent(resolvedOptions.prompt),
});
}

Expand Down
3 changes: 1 addition & 2 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ export {
GenerateResponse,
GenerationBlockedError,
GenerationResponseError,
Message,
generate,
generateStream,
normalizePart,
tagAsPreamble,
toGenerateRequest,
type GenerateOptions,
type GenerateStreamOptions,
type GenerateStreamResponse,
} from './generate.js';
export { Message } from './message.js';
export {
GenerationCommonConfigSchema,
MessageSchema,
Expand Down
131 changes: 131 additions & 0 deletions js/ai/src/message.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { extractJson } from './extract';
import { MessageData, Part, ToolRequestPart, ToolResponsePart } from './model';

/**
 * Message represents a single role's contribution to a generation. Each message
 * can contain multiple parts (for example text and an image), and each generation
 * can contain multiple messages.
 */
export class Message<T = unknown> implements MessageData {
  role: MessageData['role'];
  content: Part[];
  metadata?: Record<string, any>;

  /**
   * Normalizes a leniently specified message into strict message data.
   *
   * A bare string becomes a single text part attributed to `defaultRole`.
   * For an object, every field is carried over as-is and only `content` is
   * normalized through `parseContent`.
   *
   * @param lenientMessage A string, strict message data, or message data whose
   * `content` is a string, a single part, or a mixed array of strings/parts.
   * @param defaultRole Role assigned when `lenientMessage` is a bare string.
   * @returns Message data whose `content` is always an array of parts.
   */
  static parseData(
    lenientMessage:
      | string
      | MessageData
      // Omit the strict `content` first: intersecting it with `string` would
      // produce a type no string-content message can satisfy.
      | (Omit<MessageData, 'content'> & {
          content: string | Part | (string | Part)[];
        }),
    defaultRole: MessageData['role'] = 'user'
  ): MessageData {
    if (typeof lenientMessage === 'string') {
      return { role: defaultRole, content: [{ text: lenientMessage }] };
    }
    return {
      ...lenientMessage,
      content: Message.parseContent(lenientMessage.content),
    };
  }

  /**
   * Builds a `Message` instance from a leniently specified message. Accepts
   * the same inputs as `parseData`; bare strings become `user` messages.
   *
   * @param lenientMessage A string, strict message data, or message data with
   * lenient `content`.
   * @returns A new `Message` wrapping the normalized data.
   */
  static parse(
    lenientMessage:
      | string
      | MessageData
      | (Omit<MessageData, 'content'> & {
          content: string | Part | (string | Part)[];
        })
  ): Message {
    return new Message(Message.parseData(lenientMessage));
  }

  /**
   * Normalizes leniently specified content into an array of parts.
   *
   * @param lenientPart A string, a single part, or an array mixing strings
   * and parts.
   * @returns An array of parts in which every string has been replaced by a
   * text part. Array input yields a new array; part objects are reused as-is.
   */
  static parseContent(lenientPart: string | Part | (string | Part)[]): Part[] {
    if (typeof lenientPart === 'string') {
      return [{ text: lenientPart }];
    } else if (Array.isArray(lenientPart)) {
      return lenientPart.map((p) => (typeof p === 'string' ? { text: p } : p));
    } else {
      return [lenientPart];
    }
  }

  /**
   * Wraps plain message data. The `content` array and `metadata` object are
   * stored by reference, not copied.
   * @param message The message data to wrap.
   */
  constructor(message: MessageData) {
    this.role = message.role;
    this.content = message.content;
    this.metadata = message.metadata;
  }

  /**
   * If a message contains a `data` part, it is returned. Otherwise, the
   * `output` getter extracts the first valid JSON object or array from the
   * text contained in the message and returns it.
   *
   * @returns The structured output contained in the message.
   */
  get output(): T {
    // `data` is either a truthy value or null, so `??` selects the same
    // branch that `||` would.
    return this.data ?? extractJson<T>(this.text);
  }

  /**
   * Returns all tool response parts found in this message.
   * @returns Array of all parts that carry a `toolResponse`.
   */
  toolResponseParts(): ToolResponsePart[] {
    const res = this.content.filter((part) => !!part.toolResponse);
    return res as ToolResponsePart[];
  }

  /**
   * Concatenates all `text` parts present in the message with no delimiter.
   * @returns A string of all concatenated text parts.
   */
  get text(): string {
    return this.content.map((part) => part.text || '').join('');
  }

  /**
   * Returns the first media part detected in the message. Useful for extracting
   * (for example) an image from a generation expected to create one.
   * @returns The first detected `media` part in the message, or `null`.
   */
  get media(): { url: string; contentType?: string } | null {
    return this.content.find((part) => part.media)?.media || null;
  }

  /**
   * Returns the first detected `data` part of a message.
   * @returns The first `data` part detected in the message, or `null` when no
   * part carries data.
   */
  get data(): T | null {
    const found = this.content.find((part) => part.data);
    // Normalize "not found" to null so the runtime value matches the declared
    // return type (and the behavior of the `media` getter).
    return found ? (found.data as T) : null;
  }

  /**
   * Returns all tool requests found in this message.
   * @returns Array of all tool requests found in this message.
   */
  get toolRequests(): ToolRequestPart[] {
    return this.content.filter(
      (part) => !!part.toolRequest
    ) as ToolRequestPart[];
  }

  /**
   * Converts the Message to a plain JS object.
   * @returns Plain JS object representing the data contained in the message.
   * The `content` array is a shallow copy; `metadata` is included (by
   * reference) only when it is set.
   */
  toJSON(): MessageData {
    const out: MessageData = {
      role: this.role,
      content: [...this.content],
    };
    if (this.metadata) out.metadata = this.metadata;
    return out;
  }
}
2 changes: 1 addition & 1 deletion js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import {
GenerateResponseChunk,
GenerationBlockedError,
GenerationResponseError,
Message,
generate,
toGenerateRequest,
} from '../../src/generate.js';
import { Message } from '../../src/message.js';
import {
GenerateRequest,
GenerateResponseChunkData,
Expand Down
55 changes: 55 additions & 0 deletions js/ai/tests/message/message_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import assert from 'node:assert';
import { describe, it } from 'node:test';
import { Message } from '../../src/message';

describe('Message', () => {
  describe('.parseData()', () => {
    // Each row is [test description, lenient input, expected message data].
    const table: Array<[string, unknown, unknown]> = [
      [
        'convert string to user message',
        'i am a user message',
        { role: 'user', content: [{ text: 'i am a user message' }] },
      ],
      [
        'convert string content to Part[] content',
        {
          role: 'system',
          content: 'i am a system message',
          metadata: { extra: true },
        },
        {
          role: 'system',
          content: [{ text: 'i am a system message' }],
          metadata: { extra: true },
        },
      ],
      [
        'leave valid MessageData alone',
        { role: 'model', content: [{ text: 'i am a model message' }] },
        { role: 'model', content: [{ text: 'i am a model message' }] },
      ],
    ];

    // Register one test per row, in table order.
    table.forEach(([title, given, expected]) => {
      it(title, () => {
        const actual = Message.parseData(given as any);
        assert.deepStrictEqual(actual, expected);
      });
    });
  });
});
4 changes: 2 additions & 2 deletions js/genkit/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import {
GenerateOptions,
Message,
MessageData,
normalizePart,
tagAsPreamble,
} from '@genkit-ai/ai';
import { z } from '@genkit-ai/core';
Expand Down Expand Up @@ -186,7 +186,7 @@ export class Session<S = any> {
if (baseOptions.system) {
messages.push({
role: 'system',
content: normalizePart(baseOptions.system),
content: Message.parseContent(baseOptions.system),
});
}
delete baseOptions.system;
Expand Down
Loading
Loading