[Automatic Import] Prepare to support more connectors #191278

Merged: 1 commit, Aug 26, 2024

@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleCategorization()', async () => {
-const response = await handleCategorization(testState, mockLlm);
+const response = await handleCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

@@ -4,21 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-import type {
-ActionsClientChatOpenAI,
-ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
+import type { CategorizationNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_MAIN_PROMPT } from './prompts';
import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants';

-export async function handleCategorization(
-state: CategorizationState,
-model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleCategorization({
+state,
+model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser);
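
The change repeated across these handler files is the switch from positional (state, model) arguments to a single destructured params object typed as CategorizationNodeParams. A minimal, self-contained sketch of that pattern follows; the type and function names here are illustrative stand-ins, not code from this PR:

// Illustrative stand-ins only; the real handlers use CategorizationState and
// the Kibana ActionsClient chat models.
interface SketchState {
  errors: string[];
  finalized: boolean;
}

interface SketchModel {
  invoke(prompt: string): Promise<string>;
}

// One params object instead of positional arguments: call sites read like
// named arguments, and new members (for example an Elasticsearch client)
// can be added later without touching every caller.
interface SketchNodeParams {
  state: SketchState;
  model: SketchModel;
}

async function handleSketchNode({ state, model }: SketchNodeParams): Promise<Partial<SketchState>> {
  // The model output stands in for the JSON pipeline the real handlers parse.
  const review = await model.invoke(`found ${state.errors.length} errors`);
  return { finalized: review.length > 0 };
}

// Call sites change from handleNode(testState, mockLlm) to
// handleNode({ state, model }), which is what the updated tests exercise.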

@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleErrors()', async () => {
-const response = await handleErrors(testState, mockLlm);
+const response = await handleErrors({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

@@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-import type {
-ActionsClientChatOpenAI,
-ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_ERROR_PROMPT } from './prompts';

-export async function handleErrors(
-state: CategorizationState,
-model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleErrors({
+state,
+model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT;

const outputParser = new JsonOutputParser();

@@ -31,7 +31,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

@@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({
}));

describe('runCategorizationGraph', () => {
-const mockClient = {
+const client = {
asCurrentUser: {
ingest: {
simulate: jest.fn(),

@@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => {

it('Ensures that the graph compiles', async () => {
try {
-await getCategorizationGraph(mockClient, mockLlm);
+await getCategorizationGraph({ client, model });
} catch (error) {
-// noop
+throw Error(`getCategorizationGraph threw an error: ${error}`);
}
});

it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
-const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm);
+const categorizationGraph = await getCategorizationGraph({ client, model });

(testPipeline as jest.Mock)
.mockResolvedValueOnce(testPipelineValidResult)

@@ -151,8 +151,8 @@
let response;
try {
response = await categorizationGraph.invoke(mockedRequestWithPipeline);
-} catch (e) {
-// noop
+} catch (error) {
+throw Error(`getCategorizationGraph threw an error: ${error}`);
}

expect(response.results).toStrictEqual(categorizationExpectedResults);

@@ -5,14 +5,10 @@
* 2.0.
*/

-import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { StateGraphArgs } from '@langchain/langgraph';
import { StateGraph, END, START } from '@langchain/langgraph';
-import type {
-ActionsClientChatOpenAI,
-ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
import type { CategorizationState } from '../../types';
+import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleCategorization } from './categorization';
import { handleValidatePipeline } from '../../util/graph';

@@ -105,7 +101,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
},
};

-function modelInput(state: CategorizationState): Partial<CategorizationState> {
+function modelInput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));

@@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial<CategorizationState> {
};
}

-function modelOutput(state: CategorizationState): Partial<CategorizationState> {
+function modelOutput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
return {
finalized: true,
lastExecutedChain: 'modelOutput',

@@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial<CategorizationState> {
};
}

-function validationRouter(state: CategorizationState): string {
+function validationRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.currentProcessors).length === 0) {
return 'categorization';
}
return 'validateCategorization';
}

-function chainRouter(state: CategorizationState): string {
+function chainRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.errors).length > 0) {
return 'errors';
}

@@ -157,44 +153,51 @@ function chainRouter(state: CategorizationState): string {
return END;
}

-export async function getCategorizationGraph(
-client: IScopedClusterClient,
-model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) {
const workflow = new StateGraph({
channels: graphState,
})
-.addNode('modelInput', modelInput)
-.addNode('modelOutput', modelOutput)
+.addNode('modelInput', (state: CategorizationState) => modelInput({ state }))
+.addNode('modelOutput', (state: CategorizationState) => modelOutput({ state }))
.addNode('handleCategorization', (state: CategorizationState) =>
-handleCategorization(state, model)
+handleCategorization({ state, model })
)
.addNode('handleValidatePipeline', (state: CategorizationState) =>
-handleValidatePipeline(state, client)
+handleValidatePipeline({ state, client })
)
+.addNode('handleCategorizationValidation', (state: CategorizationState) =>
+handleCategorizationValidation({ state })
)
-.addNode('handleCategorizationValidation', handleCategorizationValidation)
.addNode('handleInvalidCategorization', (state: CategorizationState) =>
-handleInvalidCategorization(state, model)
+handleInvalidCategorization({ state, model })
)
-.addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model))
-.addNode('handleReview', (state: CategorizationState) => handleReview(state, model))
+.addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model }))
+.addNode('handleReview', (state: CategorizationState) => handleReview({ state, model }))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge('modelInput', 'handleValidatePipeline')
.addEdge('handleCategorization', 'handleValidatePipeline')
.addEdge('handleInvalidCategorization', 'handleValidatePipeline')
.addEdge('handleErrors', 'handleValidatePipeline')
.addEdge('handleReview', 'handleValidatePipeline')
-.addConditionalEdges('handleValidatePipeline', validationRouter, {
-categorization: 'handleCategorization',
-validateCategorization: 'handleCategorizationValidation',
-})
-.addConditionalEdges('handleCategorizationValidation', chainRouter, {
-modelOutput: 'modelOutput',
-errors: 'handleErrors',
-invalidCategorization: 'handleInvalidCategorization',
-review: 'handleReview',
-});
+.addConditionalEdges(
+'handleValidatePipeline',
+(state: CategorizationState) => validationRouter({ state }),
+{
+categorization: 'handleCategorization',
+validateCategorization: 'handleCategorizationValidation',
+}
+)
+.addConditionalEdges(
+'handleCategorizationValidation',
+(state: CategorizationState) => chainRouter({ state }),
+{
+modelOutput: 'modelOutput',
+errors: 'handleErrors',
+invalidCategorization: 'handleInvalidCategorization',
+review: 'handleReview',
+}
+);

const compiledCategorizationGraph = workflow.compile();
return compiledCategorizationGraph;
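
In the graph factory above, each node is now registered through a small arrow function that adapts LangGraph's plain (state) => ... node signature to the new params objects, with model and client captured from getCategorizationGraph's own arguments. Below is a dependency-free sketch of that adapter idea under invented names; the real code wires these wrappers into @langchain/langgraph's StateGraph:

// Invented stand-in types; not the real CategorizationState, chat model, or client.
interface DemoState {
  errors: string[];
}

interface DemoModel {
  invoke(prompt: string): Promise<string>;
}

interface DemoClient {
  simulatePipeline(state: DemoState): Promise<string[]>;
}

interface DemoNodeParams {
  state: DemoState;
  model: DemoModel;
}

interface DemoGraphParams {
  client: DemoClient;
  model: DemoModel;
}

// A handler written against the params-object convention.
async function handleDemoErrors({ state, model }: DemoNodeParams): Promise<Partial<DemoState>> {
  const fix = await model.invoke(`fix these errors: ${state.errors.join(', ')}`);
  return { errors: fix ? [] : state.errors };
}

// The graph framework only ever hands a node the current state, so the factory
// closes over model and client and adapts each handler to that shape.
function buildDemoNodes({ client, model }: DemoGraphParams) {
  return {
    handleErrors: (state: DemoState) => handleDemoErrors({ state, model }),
    handleValidatePipeline: async (state: DemoState) => ({
      errors: await client.simulatePipeline(state),
    }),
  };
}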

@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleInvalidCategorization()', async () => {
-const response = await handleInvalidCategorization(testState, mockLlm);
+const response = await handleInvalidCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

@@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-import type {
-ActionsClientChatOpenAI,
-ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts';

-export async function handleInvalidCategorization(
-state: CategorizationState,
-model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleInvalidCategorization({
+state,
+model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT;

const outputParser = new JsonOutputParser();

@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleReview()', async () => {
-const response = await handleReview(testState, mockLlm);
+const response = await handleReview({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

@@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
-import type {
-ActionsClientChatOpenAI,
-ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import { CATEGORIZATION_REVIEW_PROMPT } from './prompts';
import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';

-export async function handleReview(
-state: CategorizationState,
-model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleReview({
+state,
+model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser);
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, ChatModels } from '../../types';

export interface CategorizationBaseNodeParams {
state: CategorizationState;
}

export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
model: ChatModels;

Contributor: Can we not do it the other way? Like:

export interface BaseNodeParams {
  model: ChatModels;
}

export interface CategorizationNodeParams extends BaseNodeParams {
  state: CategorizationState;
}

And reuse BaseNodeParams in all graphs

Member Author: We could, yeah. However, the issue with that abstraction is that it would make it very hard to make changes to just one graph that might be model-related, and so on. At that point we would start making a type that extends BaseNodeParams only for that graph, and later another one might appear, etc.

I don't mind moving BaseNodeParams up to /server/types.ts and reusing it for all the graphs, but we should be careful not to push the abstractions too wide, as that always makes later changes messier.

Contributor: Yeah... I just saw that state is being used alone and not model, so there is no point abstracting model.

We should probably leave it this way to stay consistent and not make it complex.

}

export interface CategorizationGraphParams {
model: ChatModels;

Contributor: And model can be reused from BaseNodeParams.

Member Author: As long as we don't want to separate the graph compilation parameters, we could reuse that, yes. I don't have a strong opinion on this one.

client: IScopedClusterClient;
}
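
Given the PR title, the point of a per-graph params module like this is presumably that each additional connector graph can follow the same convention. The sketch below shows what a parallel file for another, hypothetical graph could look like; apart from IScopedClusterClient and ChatModels, every name is invented for illustration:

// Hypothetical parallel types file for some other graph; none of these
// interface or state names come from this PR.
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { ChatModels } from '../../types';

// Stand-in for that graph's own state shape.
export interface OtherGraphState {
  rawSamples: string[];
  lastExecutedChain: string;
}

export interface OtherGraphBaseNodeParams {
  state: OtherGraphState;
}

export interface OtherGraphNodeParams extends OtherGraphBaseNodeParams {
  model: ChatModels;
}

export interface OtherGraphParams {
  client: IScopedClusterClient;
  model: ChatModels;
}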