Skip to content

Commit

Permalink
[Automatic Import] Prepare to support more connectors (elastic#191278)
Browse files Browse the repository at this point in the history
This PR does not add any functionality, it adds interfaces to the
expected parameters from get*Graph and its graph nodes.
This is so it will be much easier extend this later when we might need
to add/switch types over a whole graph like we would have needed when
adding more connectors.

The PR touches a lot of files, but does not add/remove/change any
functionality at all, and the current expected function arguments are
the same, just the format is a bit different to better align with how
other plugins are doing it.

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios.

(cherry picked from commit 791f638)
  • Loading branch information
P1llus committed Aug 27, 2024
1 parent 98eb41c commit 7c77b5a
Show file tree
Hide file tree
Showing 41 changed files with 307 additions and 245 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleCategorization()', async () => {
const response = await handleCategorization(testState, mockLlm);
const response = await handleCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import type { CategorizationNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_MAIN_PROMPT } from './prompts';
import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants';

export async function handleCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleErrors()', async () => {
const response = await handleErrors(testState, mockLlm);
const response = await handleErrors({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_ERROR_PROMPT } from './prompts';

export async function handleErrors(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleErrors({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT;

const outputParser = new JsonOutputParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

Expand All @@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({
}));

describe('runCategorizationGraph', () => {
const mockClient = {
const client = {
asCurrentUser: {
ingest: {
simulate: jest.fn(),
Expand Down Expand Up @@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => {

it('Ensures that the graph compiles', async () => {
try {
await getCategorizationGraph(mockClient, mockLlm);
await getCategorizationGraph({ client, model });
} catch (error) {
// noop
throw Error(`getCategorizationGraph threw an error: ${error}`);
}
});

it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm);
const categorizationGraph = await getCategorizationGraph({ client, model });

(testPipeline as jest.Mock)
.mockResolvedValueOnce(testPipelineValidResult)
Expand All @@ -151,8 +151,8 @@ describe('runCategorizationGraph', () => {
let response;
try {
response = await categorizationGraph.invoke(mockedRequestWithPipeline);
} catch (e) {
// noop
} catch (error) {
throw Error(`getCategorizationGraph threw an error: ${error}`);
}

expect(response.results).toStrictEqual(categorizationExpectedResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
* 2.0.
*/

import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { StateGraphArgs } from '@langchain/langgraph';
import { StateGraph, END, START } from '@langchain/langgraph';
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { CategorizationState } from '../../types';
import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleCategorization } from './categorization';
import { handleValidatePipeline } from '../../util/graph';
Expand Down Expand Up @@ -105,7 +101,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
},
};

function modelInput(state: CategorizationState): Partial<CategorizationState> {
function modelInput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
Expand All @@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial<CategorizationState> {
};
}

function modelOutput(state: CategorizationState): Partial<CategorizationState> {
function modelOutput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
return {
finalized: true,
lastExecutedChain: 'modelOutput',
Expand All @@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial<CategorizationState> {
};
}

function validationRouter(state: CategorizationState): string {
function validationRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.currentProcessors).length === 0) {
return 'categorization';
}
return 'validateCategorization';
}

function chainRouter(state: CategorizationState): string {
function chainRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.errors).length > 0) {
return 'errors';
}
Expand All @@ -157,44 +153,51 @@ function chainRouter(state: CategorizationState): string {
return END;
}

export async function getCategorizationGraph(
client: IScopedClusterClient,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) {
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('modelInput', (state: CategorizationState) => modelInput({ state }))
.addNode('modelOutput', (state: CategorizationState) => modelOutput({ state }))
.addNode('handleCategorization', (state: CategorizationState) =>
handleCategorization(state, model)
handleCategorization({ state, model })
)
.addNode('handleValidatePipeline', (state: CategorizationState) =>
handleValidatePipeline(state, client)
handleValidatePipeline({ state, client })
)
.addNode('handleCategorizationValidation', (state: CategorizationState) =>
handleCategorizationValidation({ state })
)
.addNode('handleCategorizationValidation', handleCategorizationValidation)
.addNode('handleInvalidCategorization', (state: CategorizationState) =>
handleInvalidCategorization(state, model)
handleInvalidCategorization({ state, model })
)
.addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model))
.addNode('handleReview', (state: CategorizationState) => handleReview(state, model))
.addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model }))
.addNode('handleReview', (state: CategorizationState) => handleReview({ state, model }))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge('modelInput', 'handleValidatePipeline')
.addEdge('handleCategorization', 'handleValidatePipeline')
.addEdge('handleInvalidCategorization', 'handleValidatePipeline')
.addEdge('handleErrors', 'handleValidatePipeline')
.addEdge('handleReview', 'handleValidatePipeline')
.addConditionalEdges('handleValidatePipeline', validationRouter, {
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
})
.addConditionalEdges('handleCategorizationValidation', chainRouter, {
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
});
.addConditionalEdges(
'handleValidatePipeline',
(state: CategorizationState) => validationRouter({ state }),
{
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
}
)
.addConditionalEdges(
'handleCategorizationValidation',
(state: CategorizationState) => chainRouter({ state }),
{
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
}
);

const compiledCategorizationGraph = workflow.compile();
return compiledCategorizationGraph;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleInvalidCategorization()', async () => {
const response = await handleInvalidCategorization(testState, mockLlm);
const response = await handleInvalidCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts';

export async function handleInvalidCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleInvalidCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT;

const outputParser = new JsonOutputParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleReview()', async () => {
const response = await handleReview(testState, mockLlm);
const response = await handleReview({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import { CATEGORIZATION_REVIEW_PROMPT } from './prompts';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';

export async function handleReview(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleReview({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, ChatModels } from '../../types';

export interface CategorizationBaseNodeParams {
state: CategorizationState;
}

export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
model: ChatModels;
}

export interface CategorizationGraphParams {
model: ChatModels;
client: IScopedClusterClient;
}
Loading

0 comments on commit 7c77b5a

Please sign in to comment.