Skip to content

Commit

Permalink
[Obs AI Assistant] Update the simulate function calling setting to su…
Browse files Browse the repository at this point in the history
…pport "auto" (elastic#209628)

Closes elastic/obs-ai-assistant-team#198

## Summary

The simulated function calling setting is currently a boolean. It needs
to be updated to support the option `auto`.
`export type FunctionCallingMode = 'native' | 'simulated' | 'auto';`

If the setting is set to `false`, `auto` will be passed to the inference
client. If the setting is `true`, `simulated` will be passed to it.

Relates to elastic#208144


### Checklist

- [x] Any text added follows [EUI's writing
guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses
sentence case text and includes [i18n
support](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md)
- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] The PR description includes the appropriate Release Notes section,
and the correct `release_note:*` label is applied per the
[guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
  • Loading branch information
viduni94 and kibanamachine authored Feb 7, 2025
1 parent 77ea8fe commit 343b80a
Show file tree
Hide file tree
Showing 11 changed files with 38 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ import {
VisualizeESQLUserIntention,
type ChatActionClickPayload,
type Feedback,
aiAssistantSimulatedFunctionCalling,
} from '@kbn/observability-ai-assistant-plugin/public';
import type { AuthenticatedUser } from '@kbn/security-plugin/common';
import { findLastIndex } from 'lodash';
import React, { useCallback, useEffect, useRef, useState } from 'react';
import type { UseKnowledgeBaseResult } from '../hooks/use_knowledge_base';
import { ASSISTANT_SETUP_TITLE, EMPTY_CONVERSATION_TITLE, UPGRADE_LICENSE_TITLE } from '../i18n';
import { useAIAssistantChatService } from '../hooks/use_ai_assistant_chat_service';
import { useSimulatedFunctionCalling } from '../hooks/use_simulated_function_calling';
import { useGenAIConnectors } from '../hooks/use_genai_connectors';
import { useConversation } from '../hooks/use_conversation';
import { FlyoutPositionMode } from './chat_flyout';
Expand All @@ -47,6 +47,7 @@ import { WelcomeMessage } from './welcome_message';
import { useLicense } from '../hooks/use_license';
import { PromptEditor } from '../prompt_editor/prompt_editor';
import { deserializeMessage } from '../utils/deserialize_message';
import { useKibana } from '../hooks/use_kibana';

const fullHeightClassName = css`
height: 100%;
Expand Down Expand Up @@ -138,7 +139,14 @@ export function ChatBody({

const chatService = useAIAssistantChatService();

const { simulatedFunctionCallingEnabled } = useSimulatedFunctionCalling();
const {
services: { uiSettings },
} = useKibana();

const simulateFunctionCalling = uiSettings!.get<boolean>(
aiAssistantSimulatedFunctionCalling,
false
);

const { conversation, messages, next, state, stop, saveTitle } = useConversation({
initialConversationId,
Expand Down Expand Up @@ -409,7 +417,7 @@ export function ChatBody({
</div>
</EuiFlexItem>

{simulatedFunctionCallingEnabled ? (
{simulateFunctionCalling ? (
<EuiFlexItem grow={false}>
<SimulatedFunctionCallingCallout />
</EuiFlexItem>
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ describe('chatFunctionClient', () => {
messages: [],
signal: new AbortController().signal,
connectorId: 'foo',
useSimulatedFunctionCalling: false,
simulateFunctionCalling: false,
});
}).rejects.toThrowError(`Function arguments are invalid`);

Expand Down Expand Up @@ -112,7 +112,7 @@ describe('chatFunctionClient', () => {
messages: [],
signal: new AbortController().signal,
connectorId: 'foo',
useSimulatedFunctionCalling: false,
simulateFunctionCalling: false,
});

expect(result).toEqual({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,15 @@ export class ChatFunctionClient {
messages,
signal,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
}: {
chat: FunctionCallChatFunction;
name: string;
args: string | undefined;
messages: Message[];
signal: AbortSignal;
connectorId: string;
useSimulatedFunctionCalling: boolean;
simulateFunctionCalling: boolean;
}): Promise<FunctionResponse> {
const fn = this.functionRegistry.get(name);

Expand All @@ -194,7 +194,7 @@ export class ChatFunctionClient {
screenContexts: this.screenContexts,
chat,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
},
signal
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ describe('Observability AI Assistant client', () => {
expect.objectContaining({
connectorId: 'foo',
stream: false,
functionCalling: 'native',
functionCalling: 'auto',
toolChoice: expect.objectContaining({
function: 'title_conversation',
}),
Expand Down Expand Up @@ -349,7 +349,7 @@ describe('Observability AI Assistant client', () => {
messages: expect.arrayContaining([
{ role: 'user', content: 'How many alerts do I have?' },
]),
functionCalling: 'native',
functionCalling: 'auto',
toolChoice: undefined,
tools: undefined,
},
Expand Down Expand Up @@ -872,7 +872,7 @@ describe('Observability AI Assistant client', () => {
},
},
],
useSimulatedFunctionCalling: false,
simulateFunctionCalling: false,
});
});

Expand Down Expand Up @@ -919,7 +919,7 @@ describe('Observability AI Assistant client', () => {
messages: expect.arrayContaining([
{ role: 'user', content: 'How many alerts do I have?' },
]),
functionCalling: 'native',
functionCalling: 'auto',
toolChoice: 'auto',
tools: expect.any(Object),
},
Expand Down Expand Up @@ -1080,7 +1080,7 @@ describe('Observability AI Assistant client', () => {
messages: expect.arrayContaining([
{ role: 'user', content: 'How many alerts do I have?' },
]),
functionCalling: 'native',
functionCalling: 'auto',
toolChoice: 'auto',
tools: expect.any(Object),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ export class ObservabilityAIAssistantClient {
disableFunctions,
tracer: completeTracer,
connectorId,
useSimulatedFunctionCalling: simulateFunctionCalling === true,
simulateFunctionCalling,
})
);
}),
Expand Down Expand Up @@ -505,15 +505,17 @@ export class ObservabilityAIAssistantClient {
}
: ToolChoiceType.auto;
}

const options = {
connectorId,
messages: convertMessagesForInference(
messages.filter((message) => message.message.role !== MessageRole.System)
),
toolChoice,
tools,
functionCalling: (simulateFunctionCalling ? 'simulated' : 'native') as FunctionCallingMode,
functionCalling: (simulateFunctionCalling ? 'simulated' : 'auto') as FunctionCallingMode,
};

if (stream) {
return defer(() =>
this.dependencies.inferenceClient.chatComplete({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ function executeFunctionAndCatchError({
logger,
tracer,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
}: {
name: string;
args: string | undefined;
Expand All @@ -65,7 +65,7 @@ function executeFunctionAndCatchError({
logger: Logger;
tracer: LangTracer;
connectorId: string;
useSimulatedFunctionCalling: boolean;
simulateFunctionCalling: boolean;
}): Observable<MessageOrChatEvent> {
// hide token count events from functions to prevent them from
// having to deal with it as well
Expand All @@ -86,7 +86,7 @@ function executeFunctionAndCatchError({
signal,
messages,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
})
);

Expand Down Expand Up @@ -184,7 +184,7 @@ export function continueConversation({
disableFunctions,
tracer,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
}: {
messages: Message[];
functionClient: ChatFunctionClient;
Expand All @@ -201,7 +201,7 @@ export function continueConversation({
};
tracer: LangTracer;
connectorId: string;
useSimulatedFunctionCalling: boolean;
simulateFunctionCalling: boolean;
}): Observable<MessageOrChatEvent> {
let nextFunctionCallsLeft = functionCallsLeft;

Expand Down Expand Up @@ -319,7 +319,7 @@ export function continueConversation({
logger,
tracer,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
});
}

Expand Down Expand Up @@ -348,7 +348,7 @@ export function continueConversation({
disableFunctions,
tracer,
connectorId,
useSimulatedFunctionCalling,
simulateFunctionCalling,
});
})
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type RespondFunction<TArguments, TResponse extends FunctionResponse> = (
screenContexts: ObservabilityAIAssistantScreenContextRequest[];
chat: FunctionCallChatFunction;
connectorId: string;
useSimulatedFunctionCalling: boolean;
simulateFunctionCalling: boolean;
},
signal: AbortSignal
) => Promise<TResponse>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ export async function registerDocumentationFunction({
required: ['query'],
} as const,
},
async ({ arguments: { query, product }, connectorId, useSimulatedFunctionCalling }) => {
async ({ arguments: { query, product }, connectorId, simulateFunctionCalling }) => {
const response = await llmTasks!.retrieveDocumentation({
searchTerm: query,
products: product ? [product] : undefined,
max: 3,
connectorId,
request: resources.request,
functionCalling: useSimulatedFunctionCalling ? 'simulated' : 'native',
functionCalling: simulateFunctionCalling ? 'simulated' : 'auto',
});

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ export function registerQueryFunction({
function takes no input.`,
visibility: FunctionVisibility.AssistantOnly,
},
async ({ messages, connectorId, useSimulatedFunctionCalling }, signal) => {
async ({ messages, connectorId, simulateFunctionCalling }, signal) => {
const esqlFunctions = functions
.getFunctions()
.filter(
Expand All @@ -137,7 +137,7 @@ export function registerQueryFunction({
{ description: fn.description, schema: fn.parameters } as ToolDefinition,
])
),
functionCalling: useSimulatedFunctionCalling ? 'simulated' : 'native',
functionCalling: simulateFunctionCalling ? 'simulated' : 'auto',
});

const chatMessageId = v4();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export const uiSettings: Record<string, UiSettingsParams> = {
'xpack.observabilityAiAssistantManagement.settingsPage.simulatedFunctionCallingDescription',
{
defaultMessage:
'<em>[technical preview]</em> Use simulated function calling. Simulated function calling does not need API support for functions or tools, but it may decrease performance. Simulated function calling is currently always enabled for non-OpenAI connector, regardless of this setting.',
'<em>[technical preview]</em> Simulated function calling does not need API support for functions or tools, but it may decrease performance. It is currently always enabled for connectors that do not have API support for Native function calling, regardless of this setting.',
values: {
em: (chunks) => `<em>${chunks}</em>`,
},
Expand Down

0 comments on commit 343b80a

Please sign in to comment.