Skip to content

Commit

Permalink
feat(ai): Make response parsing extensible (#14196)
Browse files Browse the repository at this point in the history
Turns the response parsing method into a more flexible algorithm
that can work with multiple response content matchers. Each response
content matcher has a start and end regexp to define a match, as well
as a `contentFactory` function that turns the matched content into a
`ChatResponseContent` object.

Additionally, the parsing method has a fallback content factory that
will be applied to all unmatched parts, e.g. markdown by default.

Both, the response content matchers and the fallback content factory
and the list of matchers are extensible via DI.

Contributed on behalf of STMicroelectronics.
  • Loading branch information
planger authored Oct 2, 2024
1 parent 915c8ea commit f6ace67
Show file tree
Hide file tree
Showing 6 changed files with 414 additions and 65 deletions.
6 changes: 6 additions & 0 deletions packages/ai-chat/src/browser/ai-chat-frontend-module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import { UniversalChatAgent } from '../common/universal-chat-agent';
import { aiChatPreferences } from './ai-chat-preferences';
import { ChatAgentsVariableContribution } from '../common/chat-agents-variable-contribution';
import { FrontendChatServiceImpl } from './frontend-chat-service';
import { DefaultResponseContentMatcherProvider, DefaultResponseContentFactory, ResponseContentMatcherProvider } from '../common/response-content-matcher';

export default new ContainerModule(bind => {
bindContributionProvider(bind, Agent);
Expand All @@ -42,6 +43,11 @@ export default new ContainerModule(bind => {
bind(ChatAgentService).toService(ChatAgentServiceImpl);
bind(DefaultChatAgentId).toConstantValue({ id: OrchestratorChatAgentId });

bindContributionProvider(bind, ResponseContentMatcherProvider);
bind(DefaultResponseContentMatcherProvider).toSelf().inSingletonScope();
bind(ResponseContentMatcherProvider).toService(DefaultResponseContentMatcherProvider);
bind(DefaultResponseContentFactory).toSelf().inSingletonScope();

bind(AIVariableContribution).to(ChatAgentsVariableContribution).inSingletonScope();

bind(ChatRequestParserImpl).toSelf().inSingletonScope();
Expand Down
119 changes: 59 additions & 60 deletions packages/ai-chat/src/common/chat-agents.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import {
LanguageModel,
LanguageModelRequirement,
LanguageModelResponse,
LanguageModelStreamResponse,
PromptService,
ResolvedPromptTemplate,
ToolRequest,
Expand All @@ -37,19 +38,20 @@ import {
LanguageModelStreamResponsePart,
MessageActor,
} from '@theia/ai-core/lib/common';
import { CancellationToken, CancellationTokenSource, ILogger, isArray } from '@theia/core';
import { inject, injectable } from '@theia/core/shared/inversify';
import { CancellationToken, CancellationTokenSource, ContributionProvider, ILogger, isArray } from '@theia/core';
import { inject, injectable, named, postConstruct } from '@theia/core/shared/inversify';
import { ChatAgentService } from './chat-agent-service';
import {
ChatModel,
ChatRequestModel,
ChatRequestModelImpl,
ChatResponseContent,
CodeChatResponseContentImpl,
ErrorChatResponseContentImpl,
MarkdownChatResponseContentImpl,
ToolCallChatResponseContentImpl
} from './chat-model';
import { findFirstMatch, parseContents } from './parse-contents';
import { DefaultResponseContentFactory, ResponseContentMatcher, ResponseContentMatcherProvider } from './response-content-matcher';

/**
* A conversation consists of a sequence of ChatMessages.
Expand Down Expand Up @@ -121,6 +123,14 @@ export abstract class AbstractChatAgent {
@inject(ILogger) protected logger: ILogger;
@inject(CommunicationRecordingService) protected recordingService: CommunicationRecordingService;
@inject(PromptService) protected promptService: PromptService;

@inject(ContributionProvider) @named(ResponseContentMatcherProvider)
protected contentMatcherProviders: ContributionProvider<ResponseContentMatcherProvider>;
protected contentMatchers: ResponseContentMatcher[] = [];

@inject(DefaultResponseContentFactory)
protected defaultContentFactory: DefaultResponseContentFactory;

constructor(
public id: string,
public languageModelRequirements: LanguageModelRequirement[],
Expand All @@ -130,6 +140,11 @@ export abstract class AbstractChatAgent {
public tags: String[] = ['Chat']) {
}

@postConstruct()
init(): void {
this.contentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers);
}

async invoke(request: ChatRequestModelImpl): Promise<void> {
try {
const languageModel = await this.getLanguageModel(this.defaultLanguageModelPurpose);
Expand Down Expand Up @@ -189,6 +204,14 @@ export abstract class AbstractChatAgent {
}
}

protected parseContents(text: string): ChatResponseContent[] {
return parseContents(
text,
this.contentMatchers,
this.defaultContentFactory?.create.bind(this.defaultContentFactory)
);
};

protected handleError(request: ChatRequestModelImpl, error: Error): void {
request.response.response.addContent(new ErrorChatResponseContentImpl(error));
request.response.error(error);
Expand Down Expand Up @@ -281,9 +304,8 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {

protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: ChatRequestModelImpl): Promise<void> {
if (isLanguageModelTextResponse(languageModelResponse)) {
request.response.response.addContent(
new MarkdownChatResponseContentImpl(languageModelResponse.text)
);
const contents = this.parseContents(languageModelResponse.text);
request.response.response.addContents(contents);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -295,57 +317,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
return;
}
if (isLanguageModelStreamResponse(languageModelResponse)) {
for await (const token of languageModelResponse.stream) {
const newContents = this.parse(token, request.response.response.content);
if (isArray(newContents)) {
newContents.forEach(newContent => request.response.response.addContent(newContent));
} else {
request.response.response.addContent(newContents);
}

const lastContent = request.response.response.content.pop();
if (lastContent === undefined) {
return;
}
const text = lastContent.asString?.();
if (text === undefined) {
return;
}
let curSearchIndex = 0;
const result: ChatResponseContent[] = [];
while (curSearchIndex < text.length) {
// find start of code block: ```[language]\n<code>[\n]```
const codeStartIndex = text.indexOf('```', curSearchIndex);
if (codeStartIndex === -1) {
break;
}

// find language specifier if present
const newLineIndex = text.indexOf('\n', codeStartIndex + 3);
const language = codeStartIndex + 3 < newLineIndex ? text.substring(codeStartIndex + 3, newLineIndex) : undefined;

// find end of code block
const codeEndIndex = text.indexOf('```', codeStartIndex + 3);
if (codeEndIndex === -1) {
break;
}

// add text before code block as markdown content
result.push(new MarkdownChatResponseContentImpl(text.substring(curSearchIndex, codeStartIndex)));
// add code block as code content
const codeText = text.substring(newLineIndex + 1, codeEndIndex).trimEnd();
result.push(new CodeChatResponseContentImpl(codeText, language));
curSearchIndex = codeEndIndex + 3;
}

if (result.length > 0) {
result.forEach(r => {
request.response.response.addContent(r);
});
} else {
request.response.response.addContent(lastContent);
}
}
await this.addStreamResponse(languageModelResponse, request);
request.response.complete();
this.recordingService.recordResponse({
agentId: this.id,
Expand All @@ -366,19 +338,46 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent {
);
}

private parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: ChatRequestModelImpl): Promise<void> {
for await (const token of languageModelResponse.stream) {
const newContents = this.parse(token, request.response.response.content);
if (isArray(newContents)) {
request.response.response.addContents(newContents);
} else {
request.response.response.addContent(newContents);
}

const lastContent = request.response.response.content.pop();
if (lastContent === undefined) {
return;
}
const text = lastContent.asString?.();
if (text === undefined) {
return;
}

const result: ChatResponseContent[] = findFirstMatch(this.contentMatchers, text) ? this.parseContents(text) : [];
if (result.length > 0) {
request.response.response.addContents(result);
} else {
request.response.response.addContent(lastContent);
}
}
}

protected parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] {
const content = token.content;
// eslint-disable-next-line no-null/no-null
if (content !== undefined && content !== null) {
return new MarkdownChatResponseContentImpl(content);
return this.defaultContentFactory.create(content);
}
const toolCalls = token.tool_calls;
if (toolCalls !== undefined) {
const toolCallContents = toolCalls.map(toolCall =>
new ToolCallChatResponseContentImpl(toolCall.id, toolCall.function?.name, toolCall.function?.arguments, toolCall.finished, toolCall.result));
return toolCallContents;
}
return new MarkdownChatResponseContentImpl('');
return this.defaultContentFactory.create('');
}

}
18 changes: 13 additions & 5 deletions packages/ai-chat/src/common/chat-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -601,10 +601,20 @@ class ChatResponseImpl implements ChatResponse {
return this._content;
}

addContents(contents: ChatResponseContent[]): void {
contents.forEach(c => this.doAddContent(c));
this._onDidChangeEmitter.fire();
}

addContent(nextContent: ChatResponseContent): void {
// TODO: Support more complex merges affecting different content than the last, e.g. via some kind of ProcessorRegistry
// TODO: Support more of the built-in VS Code behavior, see
// https://github.com/microsoft/vscode/blob/a2cab7255c0df424027be05d58e1b7b941f4ea60/src/vs/workbench/contrib/chat/common/chatModel.ts#L188-L244
this.doAddContent(nextContent);
this._onDidChangeEmitter.fire();
}

protected doAddContent(nextContent: ChatResponseContent): void {
if (ToolCallChatResponseContent.is(nextContent) && nextContent.id !== undefined) {
const fittingTool = this._content.find(c => ToolCallChatResponseContent.is(c) && c.id === nextContent.id);
if (fittingTool !== undefined) {
Expand All @@ -613,10 +623,9 @@ class ChatResponseImpl implements ChatResponse {
this._content.push(nextContent);
}
} else {
const lastElement =
this._content.length > 0
? this._content[this._content.length - 1]
: undefined;
const lastElement = this._content.length > 0
? this._content[this._content.length - 1]
: undefined;
if (lastElement?.kind === nextContent.kind && ChatResponseContent.hasMerge(lastElement)) {
const mergeSuccess = lastElement.merge(nextContent);
if (!mergeSuccess) {
Expand All @@ -627,7 +636,6 @@ class ChatResponseImpl implements ChatResponse {
}
}
this._updateResponseRepresentation();
this._onDidChangeEmitter.fire();
}

protected _updateResponseRepresentation(): void {
Expand Down
Loading

0 comments on commit f6ace67

Please sign in to comment.