From 9228c83bcc7d06331e229cf5cb24110531a5f0b3 Mon Sep 17 00:00:00 2001 From: Marius Iversen Date: Mon, 26 Aug 2024 13:15:20 +0200 Subject: [PATCH] add param types to all graphs, preparing for supporting more connectors --- .../categorization/categorization.test.ts | 6 +- .../graphs/categorization/categorization.ts | 15 ++--- .../graphs/categorization/errors.test.ts | 6 +- .../server/graphs/categorization/errors.ts | 16 ++--- .../graphs/categorization/graph.test.ts | 14 ++-- .../server/graphs/categorization/graph.ts | 65 ++++++++++--------- .../graphs/categorization/invalid.test.ts | 6 +- .../server/graphs/categorization/invalid.ts | 16 ++--- .../graphs/categorization/review.test.ts | 6 +- .../server/graphs/categorization/review.ts | 16 ++--- .../server/graphs/categorization/types.ts | 22 +++++++ .../graphs/categorization/validate.test.ts | 14 ++-- .../server/graphs/categorization/validate.ts | 9 ++- .../server/graphs/ecs/duplicates.test.ts | 6 +- .../server/graphs/ecs/duplicates.ts | 14 ++-- .../server/graphs/ecs/graph.test.ts | 16 +++-- .../server/graphs/ecs/graph.ts | 37 +++++------ .../server/graphs/ecs/index.ts | 1 + .../server/graphs/ecs/invalid.test.ts | 6 +- .../server/graphs/ecs/invalid.ts | 13 ++-- .../server/graphs/ecs/mapping.test.ts | 6 +- .../server/graphs/ecs/mapping.ts | 14 ++-- .../server/graphs/ecs/missing.test.ts | 6 +- .../server/graphs/ecs/missing.ts | 16 ++--- .../server/graphs/ecs/model.ts | 7 +- .../server/graphs/ecs/types.ts | 19 ++++++ .../server/graphs/ecs/validate.ts | 4 +- .../server/graphs/related/errors.test.ts | 6 +- .../server/graphs/related/errors.ts | 14 ++-- .../server/graphs/related/graph.test.ts | 14 ++-- .../server/graphs/related/graph.ts | 49 +++++++------- .../server/graphs/related/related.test.ts | 6 +- .../server/graphs/related/related.ts | 14 ++-- .../server/graphs/related/review.test.ts | 6 +- .../server/graphs/related/review.ts | 14 ++-- .../server/graphs/related/types.ts | 22 +++++++ .../server/routes/categorization_routes.ts | 2 +- .../server/routes/ecs_routes.ts | 2 +- .../server/routes/related_routes.ts | 2 +- .../integration_assistant/server/types.ts | 12 ++++ .../server/util/graph.ts | 13 ++-- 41 files changed, 307 insertions(+), 245 deletions(-) create mode 100644 x-pack/plugins/integration_assistant/server/graphs/categorization/types.ts create mode 100644 x-pack/plugins/integration_assistant/server/graphs/ecs/types.ts create mode 100644 x-pack/plugins/integration_assistant/server/graphs/related/types.ts diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts index 3ad0926297bbc2..cfa5517ab0f901 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: CategorizationState = categorizationTestState; +const state: CategorizationState = categorizationTestState; describe('Testing categorization handler', () => { it('handleCategorization()', async () => { - const response = await handleCategorization(testState, mockLlm); + const response = await handleCategorization({ state, model }); expect(response.currentPipeline).toStrictEqual( categorizationExpectedHandlerResponse.currentPipeline ); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts index 7f6e083d7f83fd..80e3bb4861d70d 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/categorization.ts @@ -4,21 +4,18 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; -import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types'; +import type { CategorizationNodeParams } from './types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_MAIN_PROMPT } from './prompts'; import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants'; -export async function handleCategorization( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleCategorization({ + state, + model, +}: CategorizationNodeParams): Promise> { const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts index 18d8c1842080aa..184c6c4988ad47 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: CategorizationState = categorizationTestState; +const state: CategorizationState = categorizationTestState; describe('Testing categorization handler', () => { it('handleErrors()', async () => { - const response = await handleErrors(testState, mockLlm); + const response = await handleErrors({ state, model }); expect(response.currentPipeline).toStrictEqual( categorizationExpectedHandlerResponse.currentPipeline ); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts index e875754cb823d0..789673af0ff286 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/errors.ts @@ -4,20 +4,18 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; -import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { CategorizationNodeParams } from './types'; +import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { CATEGORIZATION_ERROR_PROMPT } from './prompts'; -export async function handleErrors( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleErrors({ + state, + model, +}: CategorizationNodeParams): Promise> { const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT; const outputParser = new JsonOutputParser(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts index 8fc617120be664..8db8a8019a1ed7 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.test.ts @@ -31,7 +31,7 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: "I'll callback later.", }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; @@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({ })); describe('runCategorizationGraph', () => { - const mockClient = { + const client = { asCurrentUser: { ingest: { simulate: jest.fn(), @@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => { it('Ensures that the graph compiles', async () => { try { - await getCategorizationGraph(mockClient, mockLlm); + await getCategorizationGraph({ client, model }); } catch (error) { - // noop + throw Error(`getCategorizationGraph threw an error: ${error}`); } }); it('Runs the whole graph, with mocked outputs from the LLM.', async () => { - const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm); + const categorizationGraph = await getCategorizationGraph({ client, model }); (testPipeline as jest.Mock) .mockResolvedValueOnce(testPipelineValidResult) @@ -151,8 +151,8 @@ describe('runCategorizationGraph', () => { let response; try { response = await categorizationGraph.invoke(mockedRequestWithPipeline); - } catch (e) { - // noop + } catch (error) { + throw Error(`getCategorizationGraph threw an error: ${error}`); } expect(response.results).toStrictEqual(categorizationExpectedResults); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts index ff170a23fdf7a6..29e90a1f9b35d4 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/graph.ts @@ -5,14 +5,10 @@ * 2.0. */ -import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; import type { CategorizationState } from '../../types'; +import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types'; import { prefixSamples, formatSamples } from '../../util/samples'; import { handleCategorization } from './categorization'; import { handleValidatePipeline } from '../../util/graph'; @@ -105,7 +101,7 @@ const graphState: StateGraphArgs['channels'] = { }, }; -function modelInput(state: CategorizationState): Partial { +function modelInput({ state }: CategorizationBaseNodeParams): Partial { const samples = prefixSamples(state); const formattedSamples = formatSamples(samples); const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline)); @@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial { }; } -function modelOutput(state: CategorizationState): Partial { +function modelOutput({ state }: CategorizationBaseNodeParams): Partial { return { finalized: true, lastExecutedChain: 'modelOutput', @@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial { }; } -function validationRouter(state: CategorizationState): string { +function validationRouter({ state }: CategorizationBaseNodeParams): string { if (Object.keys(state.currentProcessors).length === 0) { return 'categorization'; } return 'validateCategorization'; } -function chainRouter(state: CategorizationState): string { +function chainRouter({ state }: CategorizationBaseNodeParams): string { if (Object.keys(state.errors).length > 0) { return 'errors'; } @@ -157,27 +153,26 @@ function chainRouter(state: CategorizationState): string { return END; } -export async function getCategorizationGraph( - client: IScopedClusterClient, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) { const workflow = new StateGraph({ channels: graphState, }) - .addNode('modelInput', modelInput) - .addNode('modelOutput', modelOutput) + .addNode('modelInput', (state: CategorizationState) => modelInput({ state })) + .addNode('modelOutput', (state: CategorizationState) => modelOutput({ state })) .addNode('handleCategorization', (state: CategorizationState) => - handleCategorization(state, model) + handleCategorization({ state, model }) ) .addNode('handleValidatePipeline', (state: CategorizationState) => - handleValidatePipeline(state, client) + handleValidatePipeline({ state, client }) + ) + .addNode('handleCategorizationValidation', (state: CategorizationState) => + handleCategorizationValidation({ state }) ) - .addNode('handleCategorizationValidation', handleCategorizationValidation) .addNode('handleInvalidCategorization', (state: CategorizationState) => - handleInvalidCategorization(state, model) + handleInvalidCategorization({ state, model }) ) - .addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model)) - .addNode('handleReview', (state: CategorizationState) => handleReview(state, model)) + .addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model })) + .addNode('handleReview', (state: CategorizationState) => handleReview({ state, model })) .addEdge(START, 'modelInput') .addEdge('modelOutput', END) .addEdge('modelInput', 'handleValidatePipeline') @@ -185,16 +180,24 @@ export async function getCategorizationGraph( .addEdge('handleInvalidCategorization', 'handleValidatePipeline') .addEdge('handleErrors', 'handleValidatePipeline') .addEdge('handleReview', 'handleValidatePipeline') - .addConditionalEdges('handleValidatePipeline', validationRouter, { - categorization: 'handleCategorization', - validateCategorization: 'handleCategorizationValidation', - }) - .addConditionalEdges('handleCategorizationValidation', chainRouter, { - modelOutput: 'modelOutput', - errors: 'handleErrors', - invalidCategorization: 'handleInvalidCategorization', - review: 'handleReview', - }); + .addConditionalEdges( + 'handleValidatePipeline', + (state: CategorizationState) => validationRouter({ state }), + { + categorization: 'handleCategorization', + validateCategorization: 'handleCategorizationValidation', + } + ) + .addConditionalEdges( + 'handleCategorizationValidation', + (state: CategorizationState) => chainRouter({ state }), + { + modelOutput: 'modelOutput', + errors: 'handleErrors', + invalidCategorization: 'handleInvalidCategorization', + review: 'handleReview', + } + ); const compiledCategorizationGraph = workflow.compile(); return compiledCategorizationGraph; diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts index 10560137093d84..35069c64902dd2 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: CategorizationState = categorizationTestState; +const state: CategorizationState = categorizationTestState; describe('Testing categorization handler', () => { it('handleInvalidCategorization()', async () => { - const response = await handleInvalidCategorization(testState, mockLlm); + const response = await handleInvalidCategorization({ state, model }); expect(response.currentPipeline).toStrictEqual( categorizationExpectedHandlerResponse.currentPipeline ); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts index 2ecbd5d34eaa4f..62f7f3101ba9a6 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/invalid.ts @@ -4,21 +4,19 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; -import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { CategorizationNodeParams } from './types'; +import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants'; import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts'; -export async function handleInvalidCategorization( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleInvalidCategorization({ + state, + model, +}: CategorizationNodeParams): Promise> { const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT; const outputParser = new JsonOutputParser(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts index 7775b69c5b6a81..4294aa6b034f4f 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(categorizationMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: CategorizationState = categorizationTestState; +const state: CategorizationState = categorizationTestState; describe('Testing categorization handler', () => { it('handleReview()', async () => { - const response = await handleReview(testState, mockLlm); + const response = await handleReview({ state, model }); expect(response.currentPipeline).toStrictEqual( categorizationExpectedHandlerResponse.currentPipeline ); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts index 7986b4d6c24234..19b8180ce33e57 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/review.ts @@ -4,21 +4,19 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import { CATEGORIZATION_REVIEW_PROMPT } from './prompts'; import type { Pipeline } from '../../../common'; -import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { CategorizationNodeParams } from './types'; +import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types'; import { combineProcessors } from '../../util/processors'; import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants'; -export async function handleReview( - state: CategorizationState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleReview({ + state, + model, +}: CategorizationNodeParams): Promise> { const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/types.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/types.ts new file mode 100644 index 00000000000000..19b1c20dff7555 --- /dev/null +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/types.ts @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; +import type { CategorizationState, ChatModels } from '../../types'; + +export interface CategorizationBaseNodeParams { + state: CategorizationState; +} + +export interface CategorizationNodeParams extends CategorizationBaseNodeParams { + model: ChatModels; +} + +export interface CategorizationGraphParams { + model: ChatModels; + client: IScopedClusterClient; +} diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.test.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.test.ts index 95c56c777a315c..0fe546c1e21b35 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.test.ts @@ -9,12 +9,12 @@ import { handleCategorizationValidation } from './validate'; import type { CategorizationState } from '../../types'; import { categorizationTestState } from '../../../__jest__/fixtures/categorization'; -const testState: CategorizationState = categorizationTestState; +const state: CategorizationState = categorizationTestState; describe('Testing categorization invalid category', () => { it('handleCategorizationValidation()', async () => { - testState.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }]; - const response = handleCategorizationValidation(testState); + state.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }]; + const response = handleCategorizationValidation({ state }); expect(response.invalidCategorization).toEqual([ { error: @@ -27,8 +27,8 @@ describe('Testing categorization invalid category', () => { describe('Testing categorization invalid type', () => { it('handleCategorizationValidation()', async () => { - testState.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }]; - const response = handleCategorizationValidation(testState); + state.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }]; + const response = handleCategorizationValidation({ state }); expect(response.invalidCategorization).toEqual([ { error: @@ -41,10 +41,10 @@ describe('Testing categorization invalid type', () => { describe('Testing categorization invalid compatibility', () => { it('handleCategorizationValidation()', async () => { - testState.pipelineResults = [ + state.pipelineResults = [ { test: 'testresult', event: { category: ['authentication'], type: ['access'] } }, ]; - const response = handleCategorizationValidation(testState); + const response = handleCategorizationValidation({ state }); expect(response.invalidCategorization).toEqual([ { error: 'event.type (access) not compatible with any of the event.category (authentication)', diff --git a/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.ts b/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.ts index af3846edbac349..6360f327521c50 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/categorization/validate.ts @@ -5,6 +5,7 @@ * 2.0. */ import type { CategorizationState } from '../../types'; +import type { CategorizationBaseNodeParams } from './types'; import { ECS_EVENT_TYPES_PER_CATEGORY, EVENT_CATEGORIES, EVENT_TYPES } from './constants'; import type { EventCategories } from './constants'; @@ -22,11 +23,9 @@ interface CategorizationError { error: string; } -export function handleCategorizationValidation(state: CategorizationState): { - previousInvalidCategorization: string; - invalidCategorization: CategorizationError[]; - lastExecutedChain: string; -} { +export function handleCategorizationValidation({ + state, +}: CategorizationBaseNodeParams): Partial { let previousInvalidCategorization = ''; const errors: CategorizationError[] = []; const pipelineResults = state.pipelineResults as PipelineResult[]; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts index 9270b2453e261f..f3a51c80a52410 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.test.ts @@ -14,15 +14,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: '{ "message": "ll callback later."}', }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: EcsMappingState = ecsTestState; +const state: EcsMappingState = ecsTestState; describe('Testing ecs handler', () => { it('handleDuplicates()', async () => { - const response = await handleDuplicates(testState, mockLlm); + const response = await handleDuplicates({ state, model }); expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' }); expect(response.lastExecutedChain).toBe('duplicateFields'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts index a9508901860dbd..5c66168fc0bfe0 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/duplicates.ts @@ -4,18 +4,16 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; +import type { EcsNodeParams } from './types'; import type { EcsMappingState } from '../../types'; import { ECS_DUPLICATES_PROMPT } from './prompts'; -export async function handleDuplicates( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleDuplicates({ + state, + model, +}: EcsNodeParams): Promise> { const outputParser = new JsonOutputParser(); const ecsDuplicatesGraph = ECS_DUPLICATES_PROMPT.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts index 62e51e7d68c71a..322d71ef4c7926 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.test.ts @@ -24,7 +24,7 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: "I'll callback later.", }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; @@ -69,16 +69,22 @@ describe('EcsGraph', () => { // When getEcsGraph runs, langgraph compiles the graph it will error if the graph has any issues. // Common issues for example detecting a node has no next step, or there is a infinite loop between them. try { - await getEcsGraph(mockLlm); + await getEcsGraph({ model }); } catch (error) { - fail(`getEcsGraph threw an error: ${error}`); + throw Error(`getEcsGraph threw an error: ${error}`); } }); it('Runs the whole graph, with mocked outputs from the LLM.', async () => { // The mocked outputs are specifically crafted to trigger ALL different conditions, allowing us to test the whole graph. // This is why we have all the expects ensuring each function was called. - const ecsGraph = await getEcsGraph(mockLlm); - const response = await ecsGraph.invoke(mockedRequest); + const ecsGraph = await getEcsGraph({ model }); + let response; + try { + response = await ecsGraph.invoke(mockedRequest); + } catch (error) { + throw Error(`getEcsGraph threw an error: ${error}`); + } + expect(response.results).toStrictEqual(ecsMappingExpectedResults); // Check if the functions were called diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts index 4b1e4c4c37791a..86b0561a22e44e 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts @@ -5,12 +5,9 @@ * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; import { END, START, StateGraph, Send } from '@langchain/langgraph'; import type { EcsMappingState } from '../../types'; +import type { EcsGraphParams, EcsBaseNodeParams } from './types'; import { modelInput, modelOutput, modelSubOutput } from './model'; import { handleDuplicates } from './duplicates'; import { handleInvalidEcs } from './invalid'; @@ -19,7 +16,7 @@ import { handleMissingKeys } from './missing'; import { handleValidateMappings } from './validate'; import { graphState } from './state'; -const handleCreateMappingChunks = async (state: EcsMappingState) => { +const handleCreateMappingChunks = async ({ state }: EcsBaseNodeParams) => { // Cherrypick a shallow copy of state to pass to subgraph const stateParams = { exAnswer: state.exAnswer, @@ -36,7 +33,7 @@ const handleCreateMappingChunks = async (state: EcsMappingState) => { return 'modelOutput'; }; -function chainRouter(state: EcsMappingState): string { +function chainRouter({ state }: EcsBaseNodeParams): string { if (Object.keys(state.duplicateFields).length > 0) { return 'duplicateFields'; } @@ -53,22 +50,22 @@ function chainRouter(state: EcsMappingState): string { } // This is added as a separate graph to be able to run these steps concurrently from handleCreateMappingChunks -async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { +async function getEcsSubGraph({ model }: EcsGraphParams) { const workflow = new StateGraph({ channels: graphState, }) - .addNode('modelSubOutput', modelSubOutput) - .addNode('handleValidation', handleValidateMappings) - .addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping(state, model)) - .addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates(state, model)) - .addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys(state, model)) - .addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs(state, model)) + .addNode('modelSubOutput', (state: EcsMappingState) => modelSubOutput({ state })) + .addNode('handleValidation', (state: EcsMappingState) => handleValidateMappings({ state })) + .addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping({ state, model })) + .addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates({ state, model })) + .addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys({ state, model })) + .addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs({ state, model })) .addEdge(START, 'handleEcsMapping') .addEdge('handleEcsMapping', 'handleValidation') .addEdge('handleDuplicates', 'handleValidation') .addEdge('handleMissingKeys', 'handleValidation') .addEdge('handleInvalidEcs', 'handleValidation') - .addConditionalEdges('handleValidation', chainRouter, { + .addConditionalEdges('handleValidation', (state: EcsMappingState) => chainRouter({ state }), { duplicateFields: 'handleDuplicates', missingKeys: 'handleMissingKeys', invalidEcsFields: 'handleInvalidEcs', @@ -81,17 +78,19 @@ async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimp return compiledEcsSubGraph; } -export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) { - const subGraph = await getEcsSubGraph(model); +export async function getEcsGraph({ model }: EcsGraphParams) { + const subGraph = await getEcsSubGraph({ model }); const workflow = new StateGraph({ channels: graphState, }) - .addNode('modelInput', modelInput) - .addNode('modelOutput', modelOutput) + .addNode('modelInput', (state: EcsMappingState) => modelInput({ state })) + .addNode('modelOutput', (state: EcsMappingState) => modelOutput({ state })) .addNode('subGraph', subGraph) .addEdge(START, 'modelInput') .addEdge('subGraph', 'modelOutput') - .addConditionalEdges('modelInput', handleCreateMappingChunks) + .addConditionalEdges('modelInput', (state: EcsMappingState) => + handleCreateMappingChunks({ state }) + ) .addEdge('modelOutput', END); const compiledEcsGraph = workflow.compile(); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/index.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/index.ts index 91ea9fed3b3d3a..4207727a315e5b 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/index.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/index.ts @@ -4,4 +4,5 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ + export { getEcsGraph } from './graph'; diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts index ce1f76ce7a7212..ad10aa5b030dfb 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.test.ts @@ -14,15 +14,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: '{ "message": "ll callback later."}', }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: EcsMappingState = ecsTestState; +const state: EcsMappingState = ecsTestState; describe('Testing ecs handlers', () => { it('handleInvalidEcs()', async () => { - const response = await handleInvalidEcs(testState, mockLlm); + const response = await handleInvalidEcs({ state, model }); expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' }); expect(response.lastExecutedChain).toBe('invalidEcs'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts index 8e2d1baf4c4232..4b050fac3ccf41 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/invalid.ts @@ -4,18 +4,15 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; import { JsonOutputParser } from '@langchain/core/output_parsers'; +import type { EcsNodeParams } from './types'; import type { EcsMappingState } from '../../types'; import { ECS_INVALID_PROMPT } from './prompts'; -export async function handleInvalidEcs( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleInvalidEcs({ + state, + model, +}: EcsNodeParams): Promise> { const outputParser = new JsonOutputParser(); const ecsInvalidEcsGraph = ECS_INVALID_PROMPT.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts index dbbfc0608d0101..92954b83863bfb 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.test.ts @@ -14,15 +14,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: '{ "message": "ll callback later."}', }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: EcsMappingState = ecsTestState; +const state: EcsMappingState = ecsTestState; describe('Testing ecs handler', () => { it('handleEcsMapping()', async () => { - const response = await handleEcsMapping(testState, mockLlm); + const response = await handleEcsMapping({ state, model }); expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' }); expect(response.lastExecutedChain).toBe('ecsMapping'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts index 30c51dcc01dd92..7e8d9643237434 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/mapping.ts @@ -4,18 +4,16 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; +import type { EcsNodeParams } from './types'; import type { EcsMappingState } from '../../types'; import { ECS_MAIN_PROMPT } from './prompts'; -export async function handleEcsMapping( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleEcsMapping({ + state, + model, +}: EcsNodeParams): Promise> { const outputParser = new JsonOutputParser(); const ecsMainGraph = ECS_MAIN_PROMPT.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts index b369d28b1e177c..35fbc51bbb2e7b 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.test.ts @@ -14,15 +14,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: '{ "message": "ll callback later."}', }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: EcsMappingState = ecsTestState; +const state: EcsMappingState = ecsTestState; describe('Testing ecs handler', () => { it('handleMissingKeys()', async () => { - const response = await handleMissingKeys(testState, mockLlm); + const response = await handleMissingKeys({ state, model }); expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' }); expect(response.lastExecutedChain).toBe('missingKeys'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts index 0a23b35bd3b723..649c9a5d1facfe 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/missing.ts @@ -4,18 +4,16 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; -import type { EcsMappingState } from '../../types'; +import { EcsNodeParams } from './types'; +import { EcsMappingState } from '../../types'; import { ECS_MISSING_KEYS_PROMPT } from './prompts'; -export async function handleMissingKeys( - state: EcsMappingState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleMissingKeys({ + state, + model, +}: EcsNodeParams): Promise> { const outputParser = new JsonOutputParser(); const ecsMissingGraph = ECS_MISSING_KEYS_PROMPT.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/model.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/model.ts index 9bc2909ab79425..44508bca4ff1a1 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/model.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/model.ts @@ -9,15 +9,16 @@ import { ECS_EXAMPLE_ANSWER, ECS_FIELDS } from './constants'; import { createPipeline } from './pipeline'; import { mergeAndChunkSamples } from './chunk'; import type { EcsMappingState } from '../../types'; +import type { EcsBaseNodeParams } from './types'; -export function modelSubOutput(state: EcsMappingState): Partial { +export function modelSubOutput({ state }: EcsBaseNodeParams): Partial { return { lastExecutedChain: 'ModelSubOutput', finalMapping: state.currentMapping, }; } -export function modelInput(state: EcsMappingState): Partial { +export function modelInput({ state }: EcsBaseNodeParams): Partial { const prefixedSamples = prefixSamples(state); const sampleChunks = mergeAndChunkSamples(prefixedSamples, state.chunkSize); return { @@ -30,7 +31,7 @@ export function modelInput(state: EcsMappingState): Partial { }; } -export function modelOutput(state: EcsMappingState): Partial { +export function modelOutput({ state }: EcsBaseNodeParams): Partial { const currentPipeline = createPipeline(state); return { finalized: true, diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/types.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/types.ts new file mode 100644 index 00000000000000..a9188d3985ac75 --- /dev/null +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/types.ts @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import type { EcsMappingState, ChatModels } from '../../types'; + +export interface EcsBaseNodeParams { + state: EcsMappingState; +} + +export interface EcsNodeParams extends EcsBaseNodeParams { + model: ChatModels; +} + +export interface EcsGraphParams { + model: ChatModels; +} diff --git a/x-pack/plugins/integration_assistant/server/graphs/ecs/validate.ts b/x-pack/plugins/integration_assistant/server/graphs/ecs/validate.ts index f347247df42460..6c3fe06699aa31 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/ecs/validate.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/ecs/validate.ts @@ -6,7 +6,7 @@ */ /* eslint-disable @typescript-eslint/no-explicit-any */ import { ECS_FULL } from '../../../common/ecs'; -import type { EcsMappingState } from '../../types'; +import type { EcsBaseNodeParams } from './types'; import { ECS_RESERVED } from './constants'; const valueFieldKeys = new Set(['target', 'confidence', 'date_formats', 'type']); @@ -152,7 +152,7 @@ export function findInvalidEcsFields(currentMapping: AnyObject): string[] { return results; } -export function handleValidateMappings(state: EcsMappingState): AnyObject { +export function handleValidateMappings({ state }: EcsBaseNodeParams): AnyObject { const missingKeys = findMissingFields(state?.combinedSamples, state?.currentMapping); const duplicateFields = findDuplicateFields(state?.prefixedSamples, state?.currentMapping); const invalidEcsFields = findInvalidEcsFields(state?.currentMapping); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts index 24dc4365dcbfff..719c3f6abfc222 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: RelatedState = relatedTestState; +const state: RelatedState = relatedTestState; describe('Testing related handler', () => { it('handleErrors()', async () => { - const response = await handleErrors(testState, mockLlm); + const response = await handleErrors({ state, model }); expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline); expect(response.lastExecutedChain).toBe('error'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts index b40e638751ee0a..5601c4b5f5e333 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/errors.ts @@ -4,21 +4,19 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { RelatedNodeParams } from './types'; import { combineProcessors } from '../../util/processors'; import { RELATED_ERROR_PROMPT } from './prompts'; import { COMMON_ERRORS } from './constants'; -export async function handleErrors( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleErrors({ + state, + model, +}: RelatedNodeParams): Promise> { const relatedErrorPrompt = RELATED_ERROR_PROMPT; const outputParser = new JsonOutputParser(); const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts index a07a715a179e19..9583a3050a38ae 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.test.ts @@ -28,7 +28,7 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: "I'll callback later.", }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; @@ -41,7 +41,7 @@ jest.mock('../../util/pipeline', () => ({ })); describe('runRelatedGraph', () => { - const mockClient = { + const client = { asCurrentUser: { indices: { getMapping: jest.fn(), @@ -106,14 +106,14 @@ describe('runRelatedGraph', () => { it('Ensures that the graph compiles', async () => { try { - await getRelatedGraph(mockClient, mockLlm); + await getRelatedGraph({ client, model }); } catch (error) { - // noop + throw Error(`getRelatedGraph threw an error: ${error}`); } }); it('Runs the whole graph, with mocked outputs from the LLM.', async () => { - const relatedGraph = await getRelatedGraph(mockClient, mockLlm); + const relatedGraph = await getRelatedGraph({ client, model }); (testPipeline as jest.Mock) .mockResolvedValueOnce(testPipelineValidResult) @@ -125,8 +125,8 @@ describe('runRelatedGraph', () => { let response; try { response = await relatedGraph.invoke(mockedRequestWithPipeline); - } catch (e) { - // noop + } catch (error) { + throw Error(`getRelatedGraph threw an error: ${error}`); } expect(response.results).toStrictEqual(relatedExpectedResults); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts index 22eb69f7d2a2d8..eb7196b7b4ecb7 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/graph.ts @@ -5,14 +5,10 @@ * 2.0. */ -import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { StateGraphArgs } from '@langchain/langgraph'; import { StateGraph, END, START } from '@langchain/langgraph'; -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; import type { RelatedState } from '../../types'; +import type { RelatedGraphParams, RelatedBaseNodeParams } from './types'; import { prefixSamples, formatSamples } from '../../util/samples'; import { handleValidatePipeline } from '../../util/graph'; import { handleRelated } from './related'; @@ -91,7 +87,7 @@ const graphState: StateGraphArgs['channels'] = { }, }; -function modelInput(state: RelatedState): Partial { +function modelInput({ state }: RelatedBaseNodeParams): Partial { const samples = prefixSamples(state); const formattedSamples = formatSamples(samples); const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline)); @@ -107,7 +103,7 @@ function modelInput(state: RelatedState): Partial { }; } -function modelOutput(state: RelatedState): Partial { +function modelOutput({ state }: RelatedBaseNodeParams): Partial { return { finalized: true, lastExecutedChain: 'modelOutput', @@ -118,14 +114,14 @@ function modelOutput(state: RelatedState): Partial { }; } -function inputRouter(state: RelatedState): string { +function inputRouter({ state }: RelatedBaseNodeParams): string { if (Object.keys(state.pipelineResults).length === 0) { return 'validatePipeline'; } return 'related'; } -function chainRouter(state: RelatedState): string { +function chainRouter({ state }: RelatedBaseNodeParams): string { if (Object.keys(state.currentProcessors).length === 0) { return 'related'; } @@ -141,34 +137,35 @@ function chainRouter(state: RelatedState): string { return END; } -export async function getRelatedGraph( - client: IScopedClusterClient, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function getRelatedGraph({ client, model }: RelatedGraphParams) { const workflow = new StateGraph({ channels: graphState }) - .addNode('modelInput', modelInput) - .addNode('modelOutput', modelOutput) - .addNode('handleRelated', (state: RelatedState) => handleRelated(state, model)) + .addNode('modelInput', (state: RelatedState) => modelInput({ state })) + .addNode('modelOutput', (state: RelatedState) => modelOutput({ state })) + .addNode('handleRelated', (state: RelatedState) => handleRelated({ state, model })) .addNode('handleValidatePipeline', (state: RelatedState) => - handleValidatePipeline(state, client) + handleValidatePipeline({ state, client }) ) - .addNode('handleErrors', (state: RelatedState) => handleErrors(state, model)) - .addNode('handleReview', (state: RelatedState) => handleReview(state, model)) + .addNode('handleErrors', (state: RelatedState) => handleErrors({ state, model })) + .addNode('handleReview', (state: RelatedState) => handleReview({ state, model })) .addEdge(START, 'modelInput') .addEdge('modelOutput', END) .addEdge('handleRelated', 'handleValidatePipeline') .addEdge('handleErrors', 'handleValidatePipeline') .addEdge('handleReview', 'handleValidatePipeline') - .addConditionalEdges('modelInput', inputRouter, { + .addConditionalEdges('modelInput', (state: RelatedState) => inputRouter({ state }), { related: 'handleRelated', validatePipeline: 'handleValidatePipeline', }) - .addConditionalEdges('handleValidatePipeline', chainRouter, { - related: 'handleRelated', - errors: 'handleErrors', - review: 'handleReview', - modelOutput: 'modelOutput', - }); + .addConditionalEdges( + 'handleValidatePipeline', + (state: RelatedState) => chainRouter({ state }), + { + related: 'handleRelated', + errors: 'handleErrors', + review: 'handleReview', + modelOutput: 'modelOutput', + } + ); const compiledRelatedGraph = workflow.compile(); return compiledRelatedGraph; diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts index 3a741020fb5303..62a09cfa64ac14 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: RelatedState = relatedTestState; +const state: RelatedState = relatedTestState; describe('Testing related handler', () => { it('handleRelated()', async () => { - const response = await handleRelated(testState, mockLlm); + const response = await handleRelated({ state, model }); expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline); expect(response.lastExecutedChain).toBe('related'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts index af3b27790da97b..172270e7f87dad 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/related.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/related.ts @@ -4,20 +4,18 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { RelatedNodeParams } from './types'; import { combineProcessors } from '../../util/processors'; import { RELATED_MAIN_PROMPT } from './prompts'; -export async function handleRelated( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleRelated({ + state, + model, +}: RelatedNodeParams): Promise> { const relatedMainPrompt = RELATED_MAIN_PROMPT; const outputParser = new JsonOutputParser(); const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts index 475f0d72b988d1..2b6085c6f4f86e 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.test.ts @@ -18,15 +18,15 @@ import { ActionsClientSimpleChatModel, } from '@kbn/langchain/server/language_models'; -const mockLlm = new FakeLLM({ +const model = new FakeLLM({ response: JSON.stringify(relatedMockProcessors, null, 2), }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel; -const testState: RelatedState = relatedTestState; +const state: RelatedState = relatedTestState; describe('Testing related handler', () => { it('handleReview()', async () => { - const response = await handleReview(testState, mockLlm); + const response = await handleReview({ state, model }); expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline); expect(response.lastExecutedChain).toBe('review'); }); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts index 31abb6ca5a60cd..300f33144b52ac 100644 --- a/x-pack/plugins/integration_assistant/server/graphs/related/review.ts +++ b/x-pack/plugins/integration_assistant/server/graphs/related/review.ts @@ -4,20 +4,18 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -import type { - ActionsClientChatOpenAI, - ActionsClientSimpleChatModel, -} from '@kbn/langchain/server/language_models'; + import { JsonOutputParser } from '@langchain/core/output_parsers'; import type { Pipeline } from '../../../common'; import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types'; +import type { RelatedNodeParams } from './types'; import { combineProcessors } from '../../util/processors'; import { RELATED_REVIEW_PROMPT } from './prompts'; -export async function handleReview( - state: RelatedState, - model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel -) { +export async function handleReview({ + state, + model, +}: RelatedNodeParams): Promise> { const relatedReviewPrompt = RELATED_REVIEW_PROMPT; const outputParser = new JsonOutputParser(); const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser); diff --git a/x-pack/plugins/integration_assistant/server/graphs/related/types.ts b/x-pack/plugins/integration_assistant/server/graphs/related/types.ts new file mode 100644 index 00000000000000..77f77fbacf6056 --- /dev/null +++ b/x-pack/plugins/integration_assistant/server/graphs/related/types.ts @@ -0,0 +1,22 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; +import { RelatedState, ChatModels } from '../../types'; + +export interface RelatedBaseNodeParams { + state: RelatedState; +} + +export interface RelatedNodeParams extends RelatedBaseNodeParams { + model: ChatModels; +} + +export interface RelatedGraphParams { + client: IScopedClusterClient; + model: ChatModels; +} diff --git a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts index 80ebe9eb652583..439ebe91db2b6d 100644 --- a/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/categorization_routes.ts @@ -92,7 +92,7 @@ export function registerCategorizationRoutes( ], }; - const graph = await getCategorizationGraph(client, model); + const graph = await getCategorizationGraph({ client, model }); const results = await graph.invoke(parameters, options); return res.ok({ body: CategorizationResponse.parse(results) }); diff --git a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts index 69b13a004e98ed..78ecf2023858ba 100644 --- a/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts +++ b/x-pack/plugins/integration_assistant/server/routes/ecs_routes.ts @@ -84,7 +84,7 @@ export function registerEcsRoutes(router: IRouter void; @@ -97,3 +103,9 @@ export interface RelatedState { results: object; lastExecutedChain: string; } + +export type ChatModels = + | ActionsClientChatOpenAI + | ActionsClientBedrockChatModel + | ActionsClientSimpleChatModel + | ActionsClientGeminiChatModel; diff --git a/x-pack/plugins/integration_assistant/server/util/graph.ts b/x-pack/plugins/integration_assistant/server/util/graph.ts index a4e8141eae408f..53a7787263ce10 100644 --- a/x-pack/plugins/integration_assistant/server/util/graph.ts +++ b/x-pack/plugins/integration_assistant/server/util/graph.ts @@ -8,10 +8,15 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server'; import type { CategorizationState, RelatedState } from '../types'; import { testPipeline } from './pipeline'; -export async function handleValidatePipeline( - state: CategorizationState | RelatedState, - client: IScopedClusterClient -): Promise | Partial> { +interface HandleValidateNodeParams { + state: CategorizationState | RelatedState; + client: IScopedClusterClient; +} + +export async function handleValidatePipeline({ + state, + client, +}: HandleValidateNodeParams): Promise | Partial> { const previousError = JSON.stringify(state.errors, null, 2); const results = await testPipeline(state.rawSamples, state.currentPipeline, client); return {