Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ai): Make response parsing extensible #14196

Merged
merged 4 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions packages/ai-chat/src/browser/ai-chat-frontend-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import { UniversalChatAgent } from '../common/universal-chat-agent';
import { aiChatPreferences } from './ai-chat-preferences';
import { ChatAgentsVariableContribution } from '../common/chat-agents-variable-contribution';
import { FrontendChatServiceImpl } from './frontend-chat-service';
import { DefaultResponseContentMatcherProvider, DefaultResponseContentFactory, ResponseContentMatcherProvider } from '../common/response-content-matcher';

export default new ContainerModule(bind => {
bindContributionProvider(bind, Agent);
Expand All @@ -42,6 +43,11 @@ export default new ContainerModule(bind => {
bind(ChatAgentService).toService(ChatAgentServiceImpl);
bind(DefaultChatAgentId).toConstantValue({ id: OrchestratorChatAgentId });

bindContributionProvider(bind, ResponseContentMatcherProvider);
bind(DefaultResponseContentMatcherProvider).toSelf().inSingletonScope();
bind(ResponseContentMatcherProvider).toService(DefaultResponseContentMatcherProvider);
bind(DefaultResponseContentFactory).toSelf().inSingletonScope();

bind(AIVariableContribution).to(ChatAgentsVariableContribution).inSingletonScope();

bind(ChatRequestParserImpl).toSelf().inSingletonScope();
Expand Down
121 changes: 61 additions & 60 deletions packages/ai-chat/src/common/chat-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
LanguageModel,
LanguageModelRequirement,
LanguageModelResponse,
LanguageModelStreamResponse,
PromptService,
ResolvedPromptTemplate,
ToolRequest,
Expand All @@ -37,19 +38,20 @@ import {
LanguageModelStreamResponsePart,
MessageActor,
} from '@theia/ai-core/lib/common';
import { CancellationToken, CancellationTokenSource, ILogger, isArray } from '@theia/core';
import { inject, injectable } from '@theia/core/shared/inversify';
import { CancellationToken, CancellationTokenSource, ContributionProvider, ILogger, isArray } from '@theia/core';
import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify';
import { ChatAgentService } from './chat-agent-service';
import {
ChatModel,
ChatRequestModel,
ChatRequestModelImpl,
ChatResponseContent,
CodeChatResponseContentImpl,
ErrorChatResponseContentImpl,
MarkdownChatResponseContentImpl,
ToolCallChatResponseContentImpl
} from './chat-model';
import { findEarliestMatch, parseContents } from './parse-contents';
import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher';

/**
* A conversation consists of a sequence of ChatMessages.
Expand Down Expand Up @@ -121,6 +123,14 @@ export abstract class AbstractChatAgent {
@inject(ILogger) protected logger: ILogger;
@inject(CommunicationRecordingService) protected recordingService: CommunicationRecordingService;
@inject(PromptService) protected promptService: PromptService;

@inject(ContributionProvider) @named(ResponseContentMatcherProvider)
protected contentMatcherProviders: ContributionProvider<ResponseContentMatcherProvider>;
protected contentMatchers: ResponseContentMatcher[] = [];

@inject(DefaultResponseContentFactory)
protected defaultContentFactory: DefaultResponseContentFactory;

constructor(
public id: string,
public languageModelRequirements: LanguageModelRequirement[],
Expand All @@ -130,6 +140,11 @@ export abstract class AbstractChatAgent {
public tags: String[] = ['Chat']) {
}

@postConstruct()
init(): void {
this.contentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers);
}

async invoke(request: ChatRequestModelImpl): Promise<void> {
try {
const languageModel = await this.getLanguageModel(this.defaultLanguageModelPurpose);
Expand Down Expand Up @@ -189,6 +204,14 @@ export abstract class AbstractChatAgent {
}
}

protected parseContents(text: string): ChatResponseContent[] {
return parseContents(
text,
this.contentMatchers,
this.defaultContentFactory?.create.bind(this.defaultContentFactory)
);
};

protected handleError(request: ChatRequestModelImpl, error: Error): void {
request.response.response.addContent(new ErrorChatResponseContentImpl(error));
request.response.error(error);
Expand Down Expand Up @@ -281,9 +304,8 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {

protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: ChatRequestModelImpl): Promise<void> {
if (isLanguageModelTextResponse(languageModelResponse)) {
request.response.response.addContent(
new MarkdownChatResponseContentImpl(languageModelResponse.text)
);
const contents = this.parseContents(languageModelResponse.text);
request.response.response.addContents(contents);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -295,57 +317,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
return;
}
if (isLanguageModelStreamResponse(languageModelResponse)) {
for await (const token of languageModelResponse.stream) {
const newContents = this.parse(token, request.response.response.content);
if (isArray(newContents)) {
newContents.forEach(newContent => request.response.response.addContent(newContent));
} else {
request.response.response.addContent(newContents);
}

const lastContent = request.response.response.content.pop();
if (lastContent === undefined) {
return;
}
const text = lastContent.asString?.();
if (text === undefined) {
return;
}
let curSearchIndex = 0;
const result: ChatResponseContent[] = [];
while (curSearchIndex < text.length) {
// find start of code block: ```[language]\n<code>[\n]```
const codeStartIndex = text.indexOf('```', curSearchIndex);
if (codeStartIndex === -1) {
break;
}

// find language specifier if present
const newLineIndex = text.indexOf('\n', codeStartIndex + 3);
const language = codeStartIndex + 3 < newLineIndex ? text.substring(codeStartIndex + 3, newLineIndex) : undefined;

// find end of code block
const codeEndIndex = text.indexOf('```', codeStartIndex + 3);
if (codeEndIndex === -1) {
break;
}

// add text before code block as markdown content
result.push(new MarkdownChatResponseContentImpl(text.substring(curSearchIndex, codeStartIndex)));
// add code block as code content
const codeText = text.substring(newLineIndex + 1, codeEndIndex).trimEnd();
result.push(new CodeChatResponseContentImpl(codeText, language));
curSearchIndex = codeEndIndex + 3;
}

if (result.length > 0) {
result.forEach(r => {
request.response.response.addContent(r);
});
} else {
request.response.response.addContent(lastContent);
}
}
await this.addStreamResponse(languageModelResponse, request);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -366,19 +338,48 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
);
}

private parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: ChatRequestModelImpl): Promise<void> {
for await (const token of languageModelResponse.stream) {
const newContents = this.parse(token, request.response.response.content);
if (isArray(newContents)) {
request.response.response.addContents(newContents);
} else {
request.response.response.addContent(newContents);
}

const lastContent = request.response.response.content.pop();
if (lastContent === undefined) {
return;
}
const text = lastContent.asString?.();
if (text === undefined) {
return;
}

const result: ChatResponseContent[] = findEarliestMatch(this.contentMatchers, text) ? this.parseContents(text) : [];
if (result.length > 0) {
result.forEach(r => {
request.response.response.addContent(r);
planger marked this conversation as resolved.
Show resolved Hide resolved
});
} else {
request.response.response.addContent(lastContent);
}
}
}

protected parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
const content = token.content;
// eslint-disable-next-line no-null/no-null
if (content !== undefined && content !== null) {
return new MarkdownChatResponseContentImpl(content);
return this.defaultContentFactory?.create(content);
planger marked this conversation as resolved.
Show resolved Hide resolved
}
const toolCalls = token.tool_calls;
if (toolCalls !== undefined) {
const toolCallContents = toolCalls.map(toolCall =>
new ToolCallChatResponseContentImpl(toolCall.id, toolCall.function?.name, toolCall.function?.arguments, toolCall.finished, toolCall.result));
return toolCallContents;
}
return new MarkdownChatResponseContentImpl('');
return this.defaultContentFactory.create('');
}

}
4 changes: 4 additions & 0 deletions packages/ai-chat/src/common/chat-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,10 @@ class ChatResponseImpl implements ChatResponse {
return this._content;
}

addContents(contents: ChatResponseContent[]): void {
contents.forEach(c => this.addContent(c));
planger marked this conversation as resolved.
Show resolved Hide resolved
}

addContent(nextContent: ChatResponseContent): void {
// TODO: Support more complex merges affecting different content than the last, e.g. via some kind of ProcessorRegistry
// TODO: Support more of the built-in VS Code behavior, see
Expand Down
142 changes: 142 additions & 0 deletions packages/ai-chat/src/common/parse-contents.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// *****************************************************************************
// Copyright (C) 2024 EclipseSource GmbH.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// http://www.eclipse.org/legal/epl-2.0.
//
// This Source Code may also be made available under the following Secondary
// Licenses when the conditions for such availability set forth in the Eclipse
// Public License v. 2.0 are satisfied: GNU General Public License, version 2
// with the GNU Classpath Exception which is available at
// https://www.gnu.org/software/classpath/license.html.
//
// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0
// *****************************************************************************

import { expect } from 'chai';
import { ChatResponseContent, CodeChatResponseContentImpl, MarkdownChatResponseContentImpl } from './chat-model';
import { parseContents } from './parse-contents';
import { CodeContentMatcher, ResponseContentMatcher } from './response-content-matcher';

export class CommandChatResponseContentImpl implements ChatResponseContent {
constructor(public readonly command: string) { }
kind = 'command';
}

export const CommandContentMatcher: ResponseContentMatcher = {
start: /^<command>$/m,
end: /^<\/command>$/m,
contentFactory: (content: string) => {
const code = content.replace(/^<command>\n|<\/command>$/g, '');
return new CommandChatResponseContentImpl(code.trim());
}
};

describe('parseContents', () => {
it('should parse code content', () => {
const text = '```typescript\nconsole.log("Hello World");\n```';
const result = parseContents(text);
expect(result).to.deep.equal([new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript')]);
});

it('should parse markdown content', () => {
const text = 'Hello **World**';
const result = parseContents(text);
expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('Hello **World**')]);
});

it('should parse multiple content blocks', () => {
const text = '```typescript\nconsole.log("Hello World");\n```\nHello **World**';
const result = parseContents(text);
expect(result).to.deep.equal([
new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'),
new MarkdownChatResponseContentImpl('\nHello **World**')
]);
});

it('should parse multiple content blocks with different languages', () => {
const text = '```typescript\nconsole.log("Hello World");\n```\n```python\nprint("Hello World")\n```';
const result = parseContents(text);
expect(result).to.deep.equal([
new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'),
new CodeChatResponseContentImpl('print("Hello World")', 'python')
]);
});

it('should parse multiple content blocks with different languages and markdown', () => {
const text = '```typescript\nconsole.log("Hello World");\n```\nHello **World**\n```python\nprint("Hello World")\n```';
const result = parseContents(text);
expect(result).to.deep.equal([
new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'),
new MarkdownChatResponseContentImpl('\nHello **World**\n'),
new CodeChatResponseContentImpl('print("Hello World")', 'python')
]);
});

it('should parse content blocks with empty content', () => {
const text = '```typescript\n```\nHello **World**\n```python\nprint("Hello World")\n```';
const result = parseContents(text);
expect(result).to.deep.equal([
new CodeChatResponseContentImpl('', 'typescript'),
new MarkdownChatResponseContentImpl('\nHello **World**\n'),
new CodeChatResponseContentImpl('print("Hello World")', 'python')
]);
});

it('should parse content with markdown, code, and markdown', () => {
const text = 'Hello **World**\n```typescript\nconsole.log("Hello World");\n```\nGoodbye **World**';
const result = parseContents(text);
expect(result).to.deep.equal([
new MarkdownChatResponseContentImpl('Hello **World**\n'),
new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'),
new MarkdownChatResponseContentImpl('\nGoodbye **World**')
]);
});

it('should handle text with no special content', () => {
const text = 'Just some plain text.';
const result = parseContents(text);
expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('Just some plain text.')]);
});

it('should handle text with only start code block', () => {
const text = '```typescript\nconsole.log("Hello World");';
const result = parseContents(text);
expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('```typescript\nconsole.log("Hello World");')]);
});

it('should handle text with only end code block', () => {
const text = 'console.log("Hello World");\n```';
const result = parseContents(text);
expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('console.log("Hello World");\n```')]);
});

it('should handle text with unmatched code block', () => {
const text = '```typescript\nconsole.log("Hello World");\n```\n```python\nprint("Hello World")';
const result = parseContents(text);
expect(result).to.deep.equal([
new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'),
new MarkdownChatResponseContentImpl('\n```python\nprint("Hello World")')
]);
});

it('should parse code block without newline after language', () => {
const text = '```typescript console.log("Hello World");```';
const result = parseContents(text);
expect(result).to.deep.equal([
new MarkdownChatResponseContentImpl('```typescript console.log("Hello World");```')
]);
});

it('should parse with matches of multiple different matchers and default', () => {
const text = '<command>\nMY_SPECIAL_COMMAND\n</command>\nHello **World**\n```python\nprint("Hello World")\n```\n<command>\nMY_SPECIAL_COMMAND2\n</command>';
const result = parseContents(text, [CodeContentMatcher, CommandContentMatcher]);
expect(result).to.deep.equal([
new CommandChatResponseContentImpl('MY_SPECIAL_COMMAND'),
new MarkdownChatResponseContentImpl('\nHello **World**\n'),
new CodeChatResponseContentImpl('print("Hello World")', 'python'),
new CommandChatResponseContentImpl('MY_SPECIAL_COMMAND2'),
]);
});
});
Loading
Loading