[8.15] [Automatic Import] Prepare to support more connectors (#191278) (#191525)

# Backport

This will backport the following commits from `main` to `8.15`:
- [[Automatic Import] Prepare to support more connectors (#191278)](#191278)

<!--- Backport version: 9.4.3 -->

### Questions?
Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Marius
Iversen","email":"marius.iversen@elastic.co"},"sourceCommit":{"committedDate":"2024-08-26T17:29:30Z","message":"[Automatic
Import] Prepare to support more connectors (#191278)\n\nThis PR does not
add any functionality, it adds interfaces to the\nexpected parameters
from get*Graph and its graph nodes.\nThis is so it will be much easier
extend this later when we might need\nto add/switch types over a whole
graph like we would have needed when\nadding more connectors.\n\nThe PR
touches a lot of files, but does not add/remove/change
any\nfunctionality at all, and the current expected function arguments
are\nthe same, just the format is a bit different to better align with
how\nother plugins are doing it.\n\n\n\n- [x] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common
scenarios.","sha":"791f638823820e8c8cc5af31e56a4f0875356dc9","branchLabelMapping":{"^v8.16.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["enhancement","release_note:skip","backport:prev-minor","v8.16.0","Team:Security-Scalability","v8.15.1"],"title":"[Automatic
Import] Prepare to support more
connectors","number":191278,"url":"https://github.com/elastic/kibana/pull/191278","mergeCommit":{"message":"[Automatic
Import] Prepare to support more connectors (#191278)\n\nThis PR does not
add any functionality, it adds interfaces to the\nexpected parameters
from get*Graph and its graph nodes.\nThis is so it will be much easier
extend this later when we might need\nto add/switch types over a whole
graph like we would have needed when\nadding more connectors.\n\nThe PR
touches a lot of files, but does not add/remove/change
any\nfunctionality at all, and the current expected function arguments
are\nthe same, just the format is a bit different to better align with
how\nother plugins are doing it.\n\n\n\n- [x] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common
scenarios.","sha":"791f638823820e8c8cc5af31e56a4f0875356dc9"}},"sourceBranch":"main","suggestedTargetBranches":["8.15"],"targetPullRequestStates":[{"branch":"main","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/191278","number":191278,"mergeCommit":{"message":"[Automatic
Import] Prepare to support more connectors (#191278)\n\nThis PR does not
add any functionality, it adds interfaces to the\nexpected parameters
from get*Graph and its graph nodes.\nThis is so it will be much easier
extend this later when we might need\nto add/switch types over a whole
graph like we would have needed when\nadding more connectors.\n\nThe PR
touches a lot of files, but does not add/remove/change
any\nfunctionality at all, and the current expected function arguments
are\nthe same, just the format is a bit different to better align with
how\nother plugins are doing it.\n\n\n\n- [x] [Unit or
functional\ntests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)\nwere
updated or added to match the most common
scenarios.","sha":"791f638823820e8c8cc5af31e56a4f0875356dc9"}},{"branch":"8.15","label":"v8.15.1","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->
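For orientation before the diff, here is the shape of the change, condensed from the categorization hunks in this commit. The interface names, import paths, and signatures are taken verbatim from the diff; the handler body is elided, and the interfaces are shown inline here although the commit defines them in a graph-local `./types` file:

```ts
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, ChatModels } from '../../types';

// New parameter interfaces added by this PR.
export interface CategorizationBaseNodeParams {
  state: CategorizationState;
}

export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
  model: ChatModels;
}

export interface CategorizationGraphParams {
  model: ChatModels;
  client: IScopedClusterClient;
}

// Before: positional arguments, with the model union repeated at every signature.
//   export async function handleCategorization(
//     state: CategorizationState,
//     model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
//   ) { ... }

// After: a single typed params object, so new fields (for example when more
// connectors are added) can be introduced without rewriting every call site.
export async function handleCategorization({
  state,
  model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
  // ...handler body unchanged by this PR, elided here...
  return {};
}
```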

Co-authored-by: Marius Iversen <marius.iversen@elastic.co>
kibanamachine and P1llus committed Aug 27, 2024
1 parent 98eb41c commit 8c908bc
Showing 41 changed files with 307 additions and 245 deletions.
@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleCategorization()', async () => {
const response = await handleCategorization(testState, mockLlm);
const response = await handleCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
@@ -4,21 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import type { CategorizationNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_MAIN_PROMPT } from './prompts';
import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants';

export async function handleCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser);
@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleErrors()', async () => {
const response = await handleErrors(testState, mockLlm);
const response = await handleErrors({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
@@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_ERROR_PROMPT } from './prompts';

export async function handleErrors(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleErrors({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT;

const outputParser = new JsonOutputParser();
@@ -31,7 +31,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

@@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({
}));

describe('runCategorizationGraph', () => {
const mockClient = {
const client = {
asCurrentUser: {
ingest: {
simulate: jest.fn(),
@@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => {

it('Ensures that the graph compiles', async () => {
try {
await getCategorizationGraph(mockClient, mockLlm);
await getCategorizationGraph({ client, model });
} catch (error) {
// noop
throw Error(`getCategorizationGraph threw an error: ${error}`);
}
});

it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm);
const categorizationGraph = await getCategorizationGraph({ client, model });

(testPipeline as jest.Mock)
.mockResolvedValueOnce(testPipelineValidResult)
@@ -151,8 +151,8 @@ describe('runCategorizationGraph', () => {
let response;
try {
response = await categorizationGraph.invoke(mockedRequestWithPipeline);
} catch (e) {
// noop
} catch (error) {
throw Error(`getCategorizationGraph threw an error: ${error}`);
}

expect(response.results).toStrictEqual(categorizationExpectedResults);
@@ -5,14 +5,10 @@
* 2.0.
*/

import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { StateGraphArgs } from '@langchain/langgraph';
import { StateGraph, END, START } from '@langchain/langgraph';
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { CategorizationState } from '../../types';
import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleCategorization } from './categorization';
import { handleValidatePipeline } from '../../util/graph';
@@ -105,7 +101,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
},
};

function modelInput(state: CategorizationState): Partial<CategorizationState> {
function modelInput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
@@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial<CategorizationState> {
};
}

function modelOutput(state: CategorizationState): Partial<CategorizationState> {
function modelOutput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
return {
finalized: true,
lastExecutedChain: 'modelOutput',
@@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial<CategorizationState> {
};
}

function validationRouter(state: CategorizationState): string {
function validationRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.currentProcessors).length === 0) {
return 'categorization';
}
return 'validateCategorization';
}

function chainRouter(state: CategorizationState): string {
function chainRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.errors).length > 0) {
return 'errors';
}
@@ -157,44 +153,51 @@ function chainRouter(state: CategorizationState): string {
return END;
}

export async function getCategorizationGraph(
client: IScopedClusterClient,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) {
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('modelInput', (state: CategorizationState) => modelInput({ state }))
.addNode('modelOutput', (state: CategorizationState) => modelOutput({ state }))
.addNode('handleCategorization', (state: CategorizationState) =>
handleCategorization(state, model)
handleCategorization({ state, model })
)
.addNode('handleValidatePipeline', (state: CategorizationState) =>
handleValidatePipeline(state, client)
handleValidatePipeline({ state, client })
)
.addNode('handleCategorizationValidation', (state: CategorizationState) =>
handleCategorizationValidation({ state })
)
.addNode('handleCategorizationValidation', handleCategorizationValidation)
.addNode('handleInvalidCategorization', (state: CategorizationState) =>
handleInvalidCategorization(state, model)
handleInvalidCategorization({ state, model })
)
.addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model))
.addNode('handleReview', (state: CategorizationState) => handleReview(state, model))
.addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model }))
.addNode('handleReview', (state: CategorizationState) => handleReview({ state, model }))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge('modelInput', 'handleValidatePipeline')
.addEdge('handleCategorization', 'handleValidatePipeline')
.addEdge('handleInvalidCategorization', 'handleValidatePipeline')
.addEdge('handleErrors', 'handleValidatePipeline')
.addEdge('handleReview', 'handleValidatePipeline')
.addConditionalEdges('handleValidatePipeline', validationRouter, {
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
})
.addConditionalEdges('handleCategorizationValidation', chainRouter, {
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
});
.addConditionalEdges(
'handleValidatePipeline',
(state: CategorizationState) => validationRouter({ state }),
{
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
}
)
.addConditionalEdges(
'handleCategorizationValidation',
(state: CategorizationState) => chainRouter({ state }),
{
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
}
);

const compiledCategorizationGraph = workflow.compile();
return compiledCategorizationGraph;
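A note on the graph wiring above: LangGraph invokes each node with the graph state, so `getCategorizationGraph` now wraps every handler and router in a small arrow function that closes over `model` and `client` from its own params object. A trimmed excerpt of that pattern, restated from the hunk above for readability (most nodes and edges omitted, so this is illustrative rather than a complete, compilable graph):

```ts
// Illustrative excerpt: the node callbacks keep LangGraph's state-only
// signature, while the closures supply the extra dependencies.
const workflow = new StateGraph({ channels: graphState })
  .addNode('handleCategorization', (state: CategorizationState) =>
    handleCategorization({ state, model })
  )
  .addNode('handleValidatePipeline', (state: CategorizationState) =>
    handleValidatePipeline({ state, client })
  )
  .addConditionalEdges(
    'handleValidatePipeline',
    (state: CategorizationState) => validationRouter({ state }),
    {
      categorization: 'handleCategorization',
      validateCategorization: 'handleCategorizationValidation',
    }
  );
```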
@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleInvalidCategorization()', async () => {
const response = await handleInvalidCategorization(testState, mockLlm);
const response = await handleInvalidCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
@@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts';

export async function handleInvalidCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleInvalidCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT;

const outputParser = new JsonOutputParser();
@@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;

const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;

describe('Testing categorization handler', () => {
it('handleReview()', async () => {
const response = await handleReview(testState, mockLlm);
const response = await handleReview({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);
@@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

import { JsonOutputParser } from '@langchain/core/output_parsers';
import { CATEGORIZATION_REVIEW_PROMPT } from './prompts';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';

export async function handleReview(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleReview({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser);
@@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, ChatModels } from '../../types';

export interface CategorizationBaseNodeParams {
state: CategorizationState;
}

export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
model: ChatModels;
}

export interface CategorizationGraphParams {
model: ChatModels;
client: IScopedClusterClient;
}
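The `ChatModels` type imported above comes from `../../types` and is not shown in this diff. Judging by the casts in the test files, it is presumably a union covering at least the two actions-client model classes used so far, which is what makes it a single place to add further connector model types later. A hypothetical sketch of its shape (an assumption, not the actual definition):

```ts
// Assumed shape only — the real alias lives in ../../types and may include
// additional connector model classes beyond these two.
import type {
  ActionsClientChatOpenAI,
  ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';

export type ChatModels = ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
```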