[Automatic Import] Prepare to support more connectors (#191278)

This PR does not add any functionality, it adds interfaces to the
expected parameters from get*Graph and its graph nodes.
This is so it will be much easier extend this later when we might need
to add/switch types over a whole graph like we would have needed when
adding more connectors.

The PR touches a lot of files, but does not add/remove/change any
functionality at all, and the current expected function arguments are
the same, just the format is a bit different to better align with how
other plugins are doing it.



- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios.
This commit is contained in:
Marius Iversen 2024-08-26 19:29:30 +02:00 committed by GitHub
parent a1e216b148
commit 791f638823
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
41 changed files with 307 additions and 245 deletions

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;
describe('Testing categorization handler', () => {
it('handleCategorization()', async () => {
const response = await handleCategorization(testState, mockLlm);
const response = await handleCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

View file

@ -4,21 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import type { CategorizationNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_MAIN_PROMPT } from './prompts';
import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants';
export async function handleCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser);

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;
describe('Testing categorization handler', () => {
it('handleErrors()', async () => {
const response = await handleErrors(testState, mockLlm);
const response = await handleErrors({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

View file

@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { CATEGORIZATION_ERROR_PROMPT } from './prompts';
export async function handleErrors(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleErrors({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT;
const outputParser = new JsonOutputParser();

View file

@ -31,7 +31,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({
}));
describe('runCategorizationGraph', () => {
const mockClient = {
const client = {
asCurrentUser: {
ingest: {
simulate: jest.fn(),
@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => {
it('Ensures that the graph compiles', async () => {
try {
await getCategorizationGraph(mockClient, mockLlm);
await getCategorizationGraph({ client, model });
} catch (error) {
// noop
throw Error(`getCategorizationGraph threw an error: ${error}`);
}
});
it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm);
const categorizationGraph = await getCategorizationGraph({ client, model });
(testPipeline as jest.Mock)
.mockResolvedValueOnce(testPipelineValidResult)
@ -151,8 +151,8 @@ describe('runCategorizationGraph', () => {
let response;
try {
response = await categorizationGraph.invoke(mockedRequestWithPipeline);
} catch (e) {
// noop
} catch (error) {
throw Error(`getCategorizationGraph threw an error: ${error}`);
}
expect(response.results).toStrictEqual(categorizationExpectedResults);

View file

@ -5,14 +5,10 @@
* 2.0.
*/
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { StateGraphArgs } from '@langchain/langgraph';
import { StateGraph, END, START } from '@langchain/langgraph';
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { CategorizationState } from '../../types';
import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleCategorization } from './categorization';
import { handleValidatePipeline } from '../../util/graph';
@ -105,7 +101,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
},
};
function modelInput(state: CategorizationState): Partial<CategorizationState> {
function modelInput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial<CategorizationState> {
};
}
function modelOutput(state: CategorizationState): Partial<CategorizationState> {
function modelOutput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
return {
finalized: true,
lastExecutedChain: 'modelOutput',
@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial<CategorizationState> {
};
}
function validationRouter(state: CategorizationState): string {
function validationRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.currentProcessors).length === 0) {
return 'categorization';
}
return 'validateCategorization';
}
function chainRouter(state: CategorizationState): string {
function chainRouter({ state }: CategorizationBaseNodeParams): string {
if (Object.keys(state.errors).length > 0) {
return 'errors';
}
@ -157,27 +153,26 @@ function chainRouter(state: CategorizationState): string {
return END;
}
export async function getCategorizationGraph(
client: IScopedClusterClient,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) {
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('modelInput', (state: CategorizationState) => modelInput({ state }))
.addNode('modelOutput', (state: CategorizationState) => modelOutput({ state }))
.addNode('handleCategorization', (state: CategorizationState) =>
handleCategorization(state, model)
handleCategorization({ state, model })
)
.addNode('handleValidatePipeline', (state: CategorizationState) =>
handleValidatePipeline(state, client)
handleValidatePipeline({ state, client })
)
.addNode('handleCategorizationValidation', (state: CategorizationState) =>
handleCategorizationValidation({ state })
)
.addNode('handleCategorizationValidation', handleCategorizationValidation)
.addNode('handleInvalidCategorization', (state: CategorizationState) =>
handleInvalidCategorization(state, model)
handleInvalidCategorization({ state, model })
)
.addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model))
.addNode('handleReview', (state: CategorizationState) => handleReview(state, model))
.addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model }))
.addNode('handleReview', (state: CategorizationState) => handleReview({ state, model }))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge('modelInput', 'handleValidatePipeline')
@ -185,16 +180,24 @@ export async function getCategorizationGraph(
.addEdge('handleInvalidCategorization', 'handleValidatePipeline')
.addEdge('handleErrors', 'handleValidatePipeline')
.addEdge('handleReview', 'handleValidatePipeline')
.addConditionalEdges('handleValidatePipeline', validationRouter, {
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
})
.addConditionalEdges('handleCategorizationValidation', chainRouter, {
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
});
.addConditionalEdges(
'handleValidatePipeline',
(state: CategorizationState) => validationRouter({ state }),
{
categorization: 'handleCategorization',
validateCategorization: 'handleCategorizationValidation',
}
)
.addConditionalEdges(
'handleCategorizationValidation',
(state: CategorizationState) => chainRouter({ state }),
{
modelOutput: 'modelOutput',
errors: 'handleErrors',
invalidCategorization: 'handleInvalidCategorization',
review: 'handleReview',
}
);
const compiledCategorizationGraph = workflow.compile();
return compiledCategorizationGraph;

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;
describe('Testing categorization handler', () => {
it('handleInvalidCategorization()', async () => {
const response = await handleInvalidCategorization(testState, mockLlm);
const response = await handleInvalidCategorization({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

View file

@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts';
export async function handleInvalidCategorization(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleInvalidCategorization({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT;
const outputParser = new JsonOutputParser();

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(categorizationMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;
describe('Testing categorization handler', () => {
it('handleReview()', async () => {
const response = await handleReview(testState, mockLlm);
const response = await handleReview({ state, model });
expect(response.currentPipeline).toStrictEqual(
categorizationExpectedHandlerResponse.currentPipeline
);

View file

@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import { CATEGORIZATION_REVIEW_PROMPT } from './prompts';
import type { Pipeline } from '../../../common';
import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { CategorizationNodeParams } from './types';
import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
import { combineProcessors } from '../../util/processors';
import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
export async function handleReview(
state: CategorizationState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleReview({
state,
model,
}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT;
const outputParser = new JsonOutputParser();
const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser);

View file

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, ChatModels } from '../../types';
export interface CategorizationBaseNodeParams {
state: CategorizationState;
}
export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
model: ChatModels;
}
export interface CategorizationGraphParams {
model: ChatModels;
client: IScopedClusterClient;
}

View file

@ -9,12 +9,12 @@ import { handleCategorizationValidation } from './validate';
import type { CategorizationState } from '../../types';
import { categorizationTestState } from '../../../__jest__/fixtures/categorization';
const testState: CategorizationState = categorizationTestState;
const state: CategorizationState = categorizationTestState;
describe('Testing categorization invalid category', () => {
it('handleCategorizationValidation()', async () => {
testState.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }];
const response = handleCategorizationValidation(testState);
state.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }];
const response = handleCategorizationValidation({ state });
expect(response.invalidCategorization).toEqual([
{
error:
@ -27,8 +27,8 @@ describe('Testing categorization invalid category', () => {
describe('Testing categorization invalid type', () => {
it('handleCategorizationValidation()', async () => {
testState.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }];
const response = handleCategorizationValidation(testState);
state.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }];
const response = handleCategorizationValidation({ state });
expect(response.invalidCategorization).toEqual([
{
error:
@ -41,10 +41,10 @@ describe('Testing categorization invalid type', () => {
describe('Testing categorization invalid compatibility', () => {
it('handleCategorizationValidation()', async () => {
testState.pipelineResults = [
state.pipelineResults = [
{ test: 'testresult', event: { category: ['authentication'], type: ['access'] } },
];
const response = handleCategorizationValidation(testState);
const response = handleCategorizationValidation({ state });
expect(response.invalidCategorization).toEqual([
{
error: 'event.type (access) not compatible with any of the event.category (authentication)',

View file

@ -5,6 +5,7 @@
* 2.0.
*/
import type { CategorizationState } from '../../types';
import type { CategorizationBaseNodeParams } from './types';
import { ECS_EVENT_TYPES_PER_CATEGORY, EVENT_CATEGORIES, EVENT_TYPES } from './constants';
import type { EventCategories } from './constants';
@ -22,11 +23,9 @@ interface CategorizationError {
error: string;
}
export function handleCategorizationValidation(state: CategorizationState): {
previousInvalidCategorization: string;
invalidCategorization: CategorizationError[];
lastExecutedChain: string;
} {
export function handleCategorizationValidation({
state,
}: CategorizationBaseNodeParams): Partial<CategorizationState> {
let previousInvalidCategorization = '';
const errors: CategorizationError[] = [];
const pipelineResults = state.pipelineResults as PipelineResult[];

View file

@ -14,15 +14,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: '{ "message": "ll callback later."}',
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: EcsMappingState = ecsTestState;
const state: EcsMappingState = ecsTestState;
describe('Testing ecs handler', () => {
it('handleDuplicates()', async () => {
const response = await handleDuplicates(testState, mockLlm);
const response = await handleDuplicates({ state, model });
expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
expect(response.lastExecutedChain).toBe('duplicateFields');
});

View file

@ -4,18 +4,16 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { EcsNodeParams } from './types';
import type { EcsMappingState } from '../../types';
import { ECS_DUPLICATES_PROMPT } from './prompts';
export async function handleDuplicates(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleDuplicates({
state,
model,
}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
const outputParser = new JsonOutputParser();
const ecsDuplicatesGraph = ECS_DUPLICATES_PROMPT.pipe(model).pipe(outputParser);

View file

@ -24,7 +24,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
@ -69,16 +69,22 @@ describe('EcsGraph', () => {
// When getEcsGraph runs, langgraph compiles the graph it will error if the graph has any issues.
// Common issues for example detecting a node has no next step, or there is a infinite loop between them.
try {
await getEcsGraph(mockLlm);
await getEcsGraph({ model });
} catch (error) {
fail(`getEcsGraph threw an error: ${error}`);
throw Error(`getEcsGraph threw an error: ${error}`);
}
});
it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
// The mocked outputs are specifically crafted to trigger ALL different conditions, allowing us to test the whole graph.
// This is why we have all the expects ensuring each function was called.
const ecsGraph = await getEcsGraph(mockLlm);
const response = await ecsGraph.invoke(mockedRequest);
const ecsGraph = await getEcsGraph({ model });
let response;
try {
response = await ecsGraph.invoke(mockedRequest);
} catch (error) {
throw Error(`getEcsGraph threw an error: ${error}`);
}
expect(response.results).toStrictEqual(ecsMappingExpectedResults);
// Check if the functions were called

View file

@ -5,12 +5,9 @@
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { END, START, StateGraph, Send } from '@langchain/langgraph';
import type { EcsMappingState } from '../../types';
import type { EcsGraphParams, EcsBaseNodeParams } from './types';
import { modelInput, modelOutput, modelSubOutput } from './model';
import { handleDuplicates } from './duplicates';
import { handleInvalidEcs } from './invalid';
@ -19,7 +16,7 @@ import { handleMissingKeys } from './missing';
import { handleValidateMappings } from './validate';
import { graphState } from './state';
const handleCreateMappingChunks = async (state: EcsMappingState) => {
const handleCreateMappingChunks = async ({ state }: EcsBaseNodeParams) => {
// Cherrypick a shallow copy of state to pass to subgraph
const stateParams = {
exAnswer: state.exAnswer,
@ -36,7 +33,7 @@ const handleCreateMappingChunks = async (state: EcsMappingState) => {
return 'modelOutput';
};
function chainRouter(state: EcsMappingState): string {
function chainRouter({ state }: EcsBaseNodeParams): string {
if (Object.keys(state.duplicateFields).length > 0) {
return 'duplicateFields';
}
@ -53,22 +50,22 @@ function chainRouter(state: EcsMappingState): string {
}
// This is added as a separate graph to be able to run these steps concurrently from handleCreateMappingChunks
async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
async function getEcsSubGraph({ model }: EcsGraphParams) {
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelSubOutput', modelSubOutput)
.addNode('handleValidation', handleValidateMappings)
.addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping(state, model))
.addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates(state, model))
.addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys(state, model))
.addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs(state, model))
.addNode('modelSubOutput', (state: EcsMappingState) => modelSubOutput({ state }))
.addNode('handleValidation', (state: EcsMappingState) => handleValidateMappings({ state }))
.addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping({ state, model }))
.addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates({ state, model }))
.addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys({ state, model }))
.addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs({ state, model }))
.addEdge(START, 'handleEcsMapping')
.addEdge('handleEcsMapping', 'handleValidation')
.addEdge('handleDuplicates', 'handleValidation')
.addEdge('handleMissingKeys', 'handleValidation')
.addEdge('handleInvalidEcs', 'handleValidation')
.addConditionalEdges('handleValidation', chainRouter, {
.addConditionalEdges('handleValidation', (state: EcsMappingState) => chainRouter({ state }), {
duplicateFields: 'handleDuplicates',
missingKeys: 'handleMissingKeys',
invalidEcsFields: 'handleInvalidEcs',
@ -81,17 +78,19 @@ async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimp
return compiledEcsSubGraph;
}
export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
const subGraph = await getEcsSubGraph(model);
export async function getEcsGraph({ model }: EcsGraphParams) {
const subGraph = await getEcsSubGraph({ model });
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('modelInput', (state: EcsMappingState) => modelInput({ state }))
.addNode('modelOutput', (state: EcsMappingState) => modelOutput({ state }))
.addNode('subGraph', subGraph)
.addEdge(START, 'modelInput')
.addEdge('subGraph', 'modelOutput')
.addConditionalEdges('modelInput', handleCreateMappingChunks)
.addConditionalEdges('modelInput', (state: EcsMappingState) =>
handleCreateMappingChunks({ state })
)
.addEdge('modelOutput', END);
const compiledEcsGraph = workflow.compile();

View file

@ -4,4 +4,5 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { getEcsGraph } from './graph';

View file

@ -14,15 +14,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: '{ "message": "ll callback later."}',
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: EcsMappingState = ecsTestState;
const state: EcsMappingState = ecsTestState;
describe('Testing ecs handlers', () => {
it('handleInvalidEcs()', async () => {
const response = await handleInvalidEcs(testState, mockLlm);
const response = await handleInvalidEcs({ state, model });
expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
expect(response.lastExecutedChain).toBe('invalidEcs');
});

View file

@ -4,18 +4,15 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { EcsNodeParams } from './types';
import type { EcsMappingState } from '../../types';
import { ECS_INVALID_PROMPT } from './prompts';
export async function handleInvalidEcs(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleInvalidEcs({
state,
model,
}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
const outputParser = new JsonOutputParser();
const ecsInvalidEcsGraph = ECS_INVALID_PROMPT.pipe(model).pipe(outputParser);

View file

@ -14,15 +14,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: '{ "message": "ll callback later."}',
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: EcsMappingState = ecsTestState;
const state: EcsMappingState = ecsTestState;
describe('Testing ecs handler', () => {
it('handleEcsMapping()', async () => {
const response = await handleEcsMapping(testState, mockLlm);
const response = await handleEcsMapping({ state, model });
expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
expect(response.lastExecutedChain).toBe('ecsMapping');
});

View file

@ -4,18 +4,16 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { EcsNodeParams } from './types';
import type { EcsMappingState } from '../../types';
import { ECS_MAIN_PROMPT } from './prompts';
export async function handleEcsMapping(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleEcsMapping({
state,
model,
}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
const outputParser = new JsonOutputParser();
const ecsMainGraph = ECS_MAIN_PROMPT.pipe(model).pipe(outputParser);

View file

@ -14,15 +14,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: '{ "message": "ll callback later."}',
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: EcsMappingState = ecsTestState;
const state: EcsMappingState = ecsTestState;
describe('Testing ecs handler', () => {
it('handleMissingKeys()', async () => {
const response = await handleMissingKeys(testState, mockLlm);
const response = await handleMissingKeys({ state, model });
expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
expect(response.lastExecutedChain).toBe('missingKeys');
});

View file

@ -4,18 +4,16 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { EcsMappingState } from '../../types';
import { EcsNodeParams } from './types';
import { EcsMappingState } from '../../types';
import { ECS_MISSING_KEYS_PROMPT } from './prompts';
export async function handleMissingKeys(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleMissingKeys({
state,
model,
}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
const outputParser = new JsonOutputParser();
const ecsMissingGraph = ECS_MISSING_KEYS_PROMPT.pipe(model).pipe(outputParser);

View file

@ -9,15 +9,16 @@ import { ECS_EXAMPLE_ANSWER, ECS_FIELDS } from './constants';
import { createPipeline } from './pipeline';
import { mergeAndChunkSamples } from './chunk';
import type { EcsMappingState } from '../../types';
import type { EcsBaseNodeParams } from './types';
export function modelSubOutput(state: EcsMappingState): Partial<EcsMappingState> {
export function modelSubOutput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
return {
lastExecutedChain: 'ModelSubOutput',
finalMapping: state.currentMapping,
};
}
export function modelInput(state: EcsMappingState): Partial<EcsMappingState> {
export function modelInput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
const prefixedSamples = prefixSamples(state);
const sampleChunks = mergeAndChunkSamples(prefixedSamples, state.chunkSize);
return {
@ -30,7 +31,7 @@ export function modelInput(state: EcsMappingState): Partial<EcsMappingState> {
};
}
export function modelOutput(state: EcsMappingState): Partial<EcsMappingState> {
export function modelOutput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
const currentPipeline = createPipeline(state);
return {
finalized: true,

View file

@ -0,0 +1,19 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { EcsMappingState, ChatModels } from '../../types';
export interface EcsBaseNodeParams {
state: EcsMappingState;
}
export interface EcsNodeParams extends EcsBaseNodeParams {
model: ChatModels;
}
export interface EcsGraphParams {
model: ChatModels;
}

View file

@ -6,7 +6,7 @@
*/
/* eslint-disable @typescript-eslint/no-explicit-any */
import { ECS_FULL } from '../../../common/ecs';
import type { EcsMappingState } from '../../types';
import type { EcsBaseNodeParams } from './types';
import { ECS_RESERVED } from './constants';
const valueFieldKeys = new Set(['target', 'confidence', 'date_formats', 'type']);
@ -152,7 +152,7 @@ export function findInvalidEcsFields(currentMapping: AnyObject): string[] {
return results;
}
export function handleValidateMappings(state: EcsMappingState): AnyObject {
export function handleValidateMappings({ state }: EcsBaseNodeParams): AnyObject {
const missingKeys = findMissingFields(state?.combinedSamples, state?.currentMapping);
const duplicateFields = findDuplicateFields(state?.prefixedSamples, state?.currentMapping);
const invalidEcsFields = findInvalidEcsFields(state?.currentMapping);

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(relatedMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: RelatedState = relatedTestState;
const state: RelatedState = relatedTestState;
describe('Testing related handler', () => {
it('handleErrors()', async () => {
const response = await handleErrors(testState, mockLlm);
const response = await handleErrors({ state, model });
expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
expect(response.lastExecutedChain).toBe('error');
});

View file

@ -4,21 +4,19 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { RelatedNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { RELATED_ERROR_PROMPT } from './prompts';
import { COMMON_ERRORS } from './constants';
export async function handleErrors(
state: RelatedState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleErrors({
state,
model,
}: RelatedNodeParams): Promise<Partial<RelatedState>> {
const relatedErrorPrompt = RELATED_ERROR_PROMPT;
const outputParser = new JsonOutputParser();
const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser);

View file

@ -28,7 +28,7 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: "I'll callback later.",
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
@ -41,7 +41,7 @@ jest.mock('../../util/pipeline', () => ({
}));
describe('runRelatedGraph', () => {
const mockClient = {
const client = {
asCurrentUser: {
indices: {
getMapping: jest.fn(),
@ -106,14 +106,14 @@ describe('runRelatedGraph', () => {
it('Ensures that the graph compiles', async () => {
try {
await getRelatedGraph(mockClient, mockLlm);
await getRelatedGraph({ client, model });
} catch (error) {
// noop
throw Error(`getRelatedGraph threw an error: ${error}`);
}
});
it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
const relatedGraph = await getRelatedGraph(mockClient, mockLlm);
const relatedGraph = await getRelatedGraph({ client, model });
(testPipeline as jest.Mock)
.mockResolvedValueOnce(testPipelineValidResult)
@ -125,8 +125,8 @@ describe('runRelatedGraph', () => {
let response;
try {
response = await relatedGraph.invoke(mockedRequestWithPipeline);
} catch (e) {
// noop
} catch (error) {
throw Error(`getRelatedGraph threw an error: ${error}`);
}
expect(response.results).toStrictEqual(relatedExpectedResults);

View file

@ -5,14 +5,10 @@
* 2.0.
*/
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { StateGraphArgs } from '@langchain/langgraph';
import { StateGraph, END, START } from '@langchain/langgraph';
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { RelatedState } from '../../types';
import type { RelatedGraphParams, RelatedBaseNodeParams } from './types';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleValidatePipeline } from '../../util/graph';
import { handleRelated } from './related';
@ -91,7 +87,7 @@ const graphState: StateGraphArgs<RelatedState>['channels'] = {
},
};
function modelInput(state: RelatedState): Partial<RelatedState> {
function modelInput({ state }: RelatedBaseNodeParams): Partial<RelatedState> {
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
@ -107,7 +103,7 @@ function modelInput(state: RelatedState): Partial<RelatedState> {
};
}
function modelOutput(state: RelatedState): Partial<RelatedState> {
function modelOutput({ state }: RelatedBaseNodeParams): Partial<RelatedState> {
return {
finalized: true,
lastExecutedChain: 'modelOutput',
@ -118,14 +114,14 @@ function modelOutput(state: RelatedState): Partial<RelatedState> {
};
}
function inputRouter(state: RelatedState): string {
function inputRouter({ state }: RelatedBaseNodeParams): string {
if (Object.keys(state.pipelineResults).length === 0) {
return 'validatePipeline';
}
return 'related';
}
function chainRouter(state: RelatedState): string {
function chainRouter({ state }: RelatedBaseNodeParams): string {
if (Object.keys(state.currentProcessors).length === 0) {
return 'related';
}
@ -141,34 +137,35 @@ function chainRouter(state: RelatedState): string {
return END;
}
export async function getRelatedGraph(
client: IScopedClusterClient,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function getRelatedGraph({ client, model }: RelatedGraphParams) {
const workflow = new StateGraph({ channels: graphState })
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('handleRelated', (state: RelatedState) => handleRelated(state, model))
.addNode('modelInput', (state: RelatedState) => modelInput({ state }))
.addNode('modelOutput', (state: RelatedState) => modelOutput({ state }))
.addNode('handleRelated', (state: RelatedState) => handleRelated({ state, model }))
.addNode('handleValidatePipeline', (state: RelatedState) =>
handleValidatePipeline(state, client)
handleValidatePipeline({ state, client })
)
.addNode('handleErrors', (state: RelatedState) => handleErrors(state, model))
.addNode('handleReview', (state: RelatedState) => handleReview(state, model))
.addNode('handleErrors', (state: RelatedState) => handleErrors({ state, model }))
.addNode('handleReview', (state: RelatedState) => handleReview({ state, model }))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge('handleRelated', 'handleValidatePipeline')
.addEdge('handleErrors', 'handleValidatePipeline')
.addEdge('handleReview', 'handleValidatePipeline')
.addConditionalEdges('modelInput', inputRouter, {
.addConditionalEdges('modelInput', (state: RelatedState) => inputRouter({ state }), {
related: 'handleRelated',
validatePipeline: 'handleValidatePipeline',
})
.addConditionalEdges('handleValidatePipeline', chainRouter, {
related: 'handleRelated',
errors: 'handleErrors',
review: 'handleReview',
modelOutput: 'modelOutput',
});
.addConditionalEdges(
'handleValidatePipeline',
(state: RelatedState) => chainRouter({ state }),
{
related: 'handleRelated',
errors: 'handleErrors',
review: 'handleReview',
modelOutput: 'modelOutput',
}
);
const compiledRelatedGraph = workflow.compile();
return compiledRelatedGraph;

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(relatedMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: RelatedState = relatedTestState;
const state: RelatedState = relatedTestState;
describe('Testing related handler', () => {
it('handleRelated()', async () => {
const response = await handleRelated(testState, mockLlm);
const response = await handleRelated({ state, model });
expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
expect(response.lastExecutedChain).toBe('related');
});

View file

@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { RelatedNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { RELATED_MAIN_PROMPT } from './prompts';
export async function handleRelated(
state: RelatedState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleRelated({
state,
model,
}: RelatedNodeParams): Promise<Partial<RelatedState>> {
const relatedMainPrompt = RELATED_MAIN_PROMPT;
const outputParser = new JsonOutputParser();
const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser);

View file

@ -18,15 +18,15 @@ import {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
const mockLlm = new FakeLLM({
const model = new FakeLLM({
response: JSON.stringify(relatedMockProcessors, null, 2),
}) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
const testState: RelatedState = relatedTestState;
const state: RelatedState = relatedTestState;
describe('Testing related handler', () => {
it('handleReview()', async () => {
const response = await handleReview(testState, mockLlm);
const response = await handleReview({ state, model });
expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
expect(response.lastExecutedChain).toBe('review');
});

View file

@ -4,20 +4,18 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import { JsonOutputParser } from '@langchain/core/output_parsers';
import type { Pipeline } from '../../../common';
import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
import type { RelatedNodeParams } from './types';
import { combineProcessors } from '../../util/processors';
import { RELATED_REVIEW_PROMPT } from './prompts';
export async function handleReview(
state: RelatedState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
export async function handleReview({
state,
model,
}: RelatedNodeParams): Promise<Partial<RelatedState>> {
const relatedReviewPrompt = RELATED_REVIEW_PROMPT;
const outputParser = new JsonOutputParser();
const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser);

View file

@ -0,0 +1,22 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import { RelatedState, ChatModels } from '../../types';
export interface RelatedBaseNodeParams {
state: RelatedState;
}
export interface RelatedNodeParams extends RelatedBaseNodeParams {
model: ChatModels;
}
export interface RelatedGraphParams {
client: IScopedClusterClient;
model: ChatModels;
}

View file

@ -92,7 +92,7 @@ export function registerCategorizationRoutes(
],
};
const graph = await getCategorizationGraph(client, model);
const graph = await getCategorizationGraph({ client, model });
const results = await graph.invoke(parameters, options);
return res.ok({ body: CategorizationResponse.parse(results) });

View file

@ -84,7 +84,7 @@ export function registerEcsRoutes(router: IRouter<IntegrationAssistantRouteHandl
],
};
const graph = await getEcsGraph(model);
const graph = await getEcsGraph({ model });
const results = await graph.invoke(parameters, options);
return res.ok({ body: EcsMappingResponse.parse(results) });

View file

@ -84,7 +84,7 @@ export function registerRelatedRoutes(router: IRouter<IntegrationAssistantRouteH
],
};
const graph = await getRelatedGraph(client, model);
const graph = await getRelatedGraph({ client, model });
const results = await graph.invoke(parameters, options);
return res.ok({ body: RelatedResponse.parse(results) });
} catch (e) {

View file

@ -6,6 +6,12 @@
*/
import type { LicensingPluginSetup, LicensingPluginStart } from '@kbn/licensing-plugin/server';
import {
ActionsClientChatOpenAI,
ActionsClientBedrockChatModel,
ActionsClientSimpleChatModel,
ActionsClientGeminiChatModel,
} from '@kbn/langchain/server';
export interface IntegrationAssistantPluginSetup {
setIsAvailable: (isAvailable: boolean) => void;
@ -97,3 +103,9 @@ export interface RelatedState {
results: object;
lastExecutedChain: string;
}
export type ChatModels =
| ActionsClientChatOpenAI
| ActionsClientBedrockChatModel
| ActionsClientSimpleChatModel
| ActionsClientGeminiChatModel;

View file

@ -8,10 +8,15 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
import type { CategorizationState, RelatedState } from '../types';
import { testPipeline } from './pipeline';
export async function handleValidatePipeline(
state: CategorizationState | RelatedState,
client: IScopedClusterClient
): Promise<Partial<CategorizationState> | Partial<RelatedState>> {
interface HandleValidateNodeParams {
state: CategorizationState | RelatedState;
client: IScopedClusterClient;
}
export async function handleValidatePipeline({
state,
client,
}: HandleValidateNodeParams): Promise<Partial<CategorizationState> | Partial<RelatedState>> {
const previousError = JSON.stringify(state.errors, null, 2);
const results = await testPipeline(state.rawSamples, state.currentPipeline, client);
return {