Skip to content

Commit

Permalink
Refactors Message into its own file (#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Oct 28, 2024
1 parent 3b4cbd4 commit dac168b
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 100 deletions.
107 changes: 13 additions & 94 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema';
import { DocumentData } from './document.js';
import { extractJson } from './extract.js';
import { generateHelper, GenerateUtilParamSchema } from './generateAction.js';
import { Message } from './message.js';
import {
GenerateRequest,
GenerateResponseChunkData,
Expand All @@ -41,88 +42,10 @@ import {
Part,
ToolDefinition,
ToolRequestPart,
ToolResponsePart,
} from './model.js';
import { ExecutablePrompt } from './prompt.js';
import { resolveTools, ToolArgument, toToolDefinition } from './tool.js';

/**
* Message represents a single role's contribution to a generation. Each message
* can contain multiple parts (for example text and an image), and each generation
* can contain multiple messages.
*/
export class Message<T = unknown> implements MessageData {
role: MessageData['role'];
content: Part[];

constructor(message: MessageData) {
this.role = message.role;
this.content = message.content;
}

/**
* If a message contains a `data` part, it is returned. Otherwise, the `output()`
* method extracts the first valid JSON object or array from the text contained in
* the message and returns it.
*
* @returns The structured output contained in the message.
*/
get output(): T {
return this.data || extractJson<T>(this.text);
}

toolResponseParts(): ToolResponsePart[] {
const res = this.content.filter((part) => !!part.toolResponse);
return res as ToolResponsePart[];
}

/**
* Concatenates all `text` parts present in the message with no delimiter.
* @returns A string of all concatenated text parts.
*/
get text(): string {
return this.content.map((part) => part.text || '').join('');
}

/**
* Returns the first media part detected in the message. Useful for extracting
* (for example) an image from a generation expected to create one.
* @returns The first detected `media` part in the message.
*/
get media(): { url: string; contentType?: string } | null {
return this.content.find((part) => part.media)?.media || null;
}

/**
* Returns the first detected `data` part of a message.
* @returns The first `data` part detected in the message (if any).
*/
get data(): T | null {
return this.content.find((part) => part.data)?.data as T | null;
}

/**
* Returns all tool request found in this message.
* @returns Array of all tool request found in this message.
*/
get toolRequests(): ToolRequestPart[] {
return this.content.filter(
(part) => !!part.toolRequest
) as ToolRequestPart[];
}

/**
* Converts the Message to a plain JS object.
* @returns Plain JS object representing the data contained in the message.
*/
toJSON(): MessageData {
return {
role: this.role,
content: [...this.content],
};
}
}

/**
* GenerateResponse is the result from a `generate()` call and contains one or
* more generated candidate messages.
Expand Down Expand Up @@ -377,29 +300,25 @@ export class GenerateResponseChunk<T = unknown>
}
}

export function normalizePart(input: string | Part | Part[]): Part[] {
if (typeof input === 'string') {
return [{ text: input }];
} else if (Array.isArray(input)) {
return input;
} else {
return [input];
}
}

export async function toGenerateRequest(
registry: Registry,
options: GenerateOptions
): Promise<GenerateRequest> {
const messages: MessageData[] = [];
if (options.system) {
messages.push({ role: 'system', content: normalizePart(options.system) });
messages.push({
role: 'system',
content: Message.parseContent(options.system),
});
}
if (options.messages) {
messages.push(...options.messages);
messages.push(...options.messages.map((m) => Message.parseData(m)));
}
if (options.prompt) {
messages.push({ role: 'user', content: normalizePart(options.prompt) });
messages.push({
role: 'user',
content: Message.parseContent(options.prompt),
});
}
if (messages.length === 0) {
throw new Error('at least one message is required in generate request');
Expand Down Expand Up @@ -443,7 +362,7 @@ export interface GenerateOptions<
/** Retrieved documents to be used as context for this generation. */
docs?: DocumentData[];
/** Conversation messages (history) for multi-turn prompting when supported by the underlying model. */
messages?: MessageData[];
messages?: (MessageData & { content: Part[] | string | (string | Part)[] })[];
/** List of registered tool names or actions to treat as a tool for this generation if supported by the underlying model. */
tools?: ToolArgument[];
/** Configuration for the generation request. */
Expand Down Expand Up @@ -585,7 +504,7 @@ export async function generate<
if (resolvedOptions.system) {
messages.push({
role: 'system',
content: normalizePart(resolvedOptions.system),
content: Message.parseContent(resolvedOptions.system),
});
}
if (resolvedOptions.messages) {
Expand All @@ -594,7 +513,7 @@ export async function generate<
if (resolvedOptions.prompt) {
messages.push({
role: 'user',
content: normalizePart(resolvedOptions.prompt),
content: Message.parseContent(resolvedOptions.prompt),
});
}

Expand Down
3 changes: 1 addition & 2 deletions js/ai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@ export {
GenerateResponse,
GenerationBlockedError,
GenerationResponseError,
Message,
generate,
generateStream,
normalizePart,
tagAsPreamble,
toGenerateRequest,
type GenerateOptions,
type GenerateStreamOptions,
type GenerateStreamResponse,
} from './generate.js';
export { Message } from './message.js';
export {
GenerationCommonConfigSchema,
MessageSchema,
Expand Down
131 changes: 131 additions & 0 deletions js/ai/src/message.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { extractJson } from './extract';
import { MessageData, Part, ToolRequestPart, ToolResponsePart } from './model';

/**
* Message represents a single role's contribution to a generation. Each message
* can contain multiple parts (for example text and an image), and each generation
* can contain multiple messages.
*/
export class Message<T = unknown> implements MessageData {
role: MessageData['role'];
content: Part[];
metadata?: Record<string, any>;

static parseData(
lenientMessage:
| string
| (MessageData & { content: string | Part | Part[]; role: string })
| MessageData,
defaultRole: MessageData['role'] = 'user'
): MessageData {
if (typeof lenientMessage === 'string') {
return { role: defaultRole, content: [{ text: lenientMessage }] };
}
return {
...lenientMessage,
content: Message.parseContent(lenientMessage.content),
};
}

static parse(
lenientMessage: string | (MessageData & { content: string }) | MessageData
): Message {
return new Message(Message.parseData(lenientMessage));
}

static parseContent(lenientPart: string | Part | (string | Part)[]): Part[] {
if (typeof lenientPart === 'string') {
return [{ text: lenientPart }];
} else if (Array.isArray(lenientPart)) {
return lenientPart.map((p) => (typeof p === 'string' ? { text: p } : p));
} else {
return [lenientPart];
}
}

constructor(message: MessageData) {
this.role = message.role;
this.content = message.content;
this.metadata = message.metadata;
}

/**
* If a message contains a `data` part, it is returned. Otherwise, the `output()`
* method extracts the first valid JSON object or array from the text contained in
* the message and returns it.
*
* @returns The structured output contained in the message.
*/
get output(): T {
return this.data || extractJson<T>(this.text);
}

toolResponseParts(): ToolResponsePart[] {
const res = this.content.filter((part) => !!part.toolResponse);
return res as ToolResponsePart[];
}

/**
* Concatenates all `text` parts present in the message with no delimiter.
* @returns A string of all concatenated text parts.
*/
get text(): string {
return this.content.map((part) => part.text || '').join('');
}

/**
* Returns the first media part detected in the message. Useful for extracting
* (for example) an image from a generation expected to create one.
* @returns The first detected `media` part in the message.
*/
get media(): { url: string; contentType?: string } | null {
return this.content.find((part) => part.media)?.media || null;
}

/**
* Returns the first detected `data` part of a message.
* @returns The first `data` part detected in the message (if any).
*/
get data(): T | null {
return this.content.find((part) => part.data)?.data as T | null;
}

/**
* Returns all tool request found in this message.
* @returns Array of all tool request found in this message.
*/
get toolRequests(): ToolRequestPart[] {
return this.content.filter(
(part) => !!part.toolRequest
) as ToolRequestPart[];
}

/**
* Converts the Message to a plain JS object.
* @returns Plain JS object representing the data contained in the message.
*/
toJSON(): MessageData {
let out: MessageData = {
role: this.role,
content: [...this.content],
};
if (this.metadata) out.metadata = this.metadata;
return out;
}
}
2 changes: 1 addition & 1 deletion js/ai/tests/generate/generate_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import {
GenerateResponseChunk,
GenerationBlockedError,
GenerationResponseError,
Message,
generate,
toGenerateRequest,
} from '../../src/generate.js';
import { Message } from '../../src/message.js';
import {
GenerateRequest,
GenerateResponseChunkData,
Expand Down
55 changes: 55 additions & 0 deletions js/ai/tests/message/message_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import assert from 'node:assert';
import { describe, it } from 'node:test';
import { Message } from '../../src/message';

describe('Message', () => {
describe('.parseData()', () => {
const testCases = [
{
desc: 'convert string to user message',
input: 'i am a user message',
want: { role: 'user', content: [{ text: 'i am a user message' }] },
},
{
desc: 'convert string content to Part[] content',
input: {
role: 'system',
content: 'i am a system message',
metadata: { extra: true },
},
want: {
role: 'system',
content: [{ text: 'i am a system message' }],
metadata: { extra: true },
},
},
{
desc: 'leave valid MessageData alone',
input: { role: 'model', content: [{ text: 'i am a model message' }] },
want: { role: 'model', content: [{ text: 'i am a model message' }] },
},
];

for (const t of testCases) {
it(t.desc, () => {
assert.deepStrictEqual(Message.parseData(t.input as any), t.want);
});
}
});
});
4 changes: 2 additions & 2 deletions js/genkit/src/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import {
GenerateOptions,
Message,
MessageData,
normalizePart,
tagAsPreamble,
} from '@genkit-ai/ai';
import { z } from '@genkit-ai/core';
Expand Down Expand Up @@ -186,7 +186,7 @@ export class Session<S = any> {
if (baseOptions.system) {
messages.push({
role: 'system',
content: normalizePart(baseOptions.system),
content: Message.parseContent(baseOptions.system),
});
}
delete baseOptions.system;
Expand Down
Loading

0 comments on commit dac168b

Please sign in to comment.