[Automatic Import] Prepare to support more connectors (#191278)

This PR does not add any functionality; it adds interfaces for the expected parameters of get*Graph and its graph nodes, so it will be much easier to extend later when we need to add or switch types across a whole graph, as we would have needed to when adding more connectors. The PR touches a lot of files, but it does not add, remove, or change any functionality at all; the expected function arguments stay the same, only their format is slightly different, to better align with how other plugins do it.

- [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios.

parent a1e216b148
commit 791f638823

41 changed files with 307 additions and 245 deletions
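The heart of the change is mechanical: every graph node and every get*Graph factory switches from positional (state, model) arguments to a single typed parameter object, so a later connector addition only has to touch the shared interfaces rather than every signature and call site. Below is a minimal TypeScript sketch of that pattern with simplified stand-in types; the real code uses the *BaseNodeParams / *NodeParams / *GraphParams interfaces and the ChatModels union introduced in the diff that follows.

```ts
// Sketch only: GraphState and ChatModel are simplified stand-ins for the
// plugin's real state types and its ChatModels union.
interface GraphState {
  lastExecutedChain: string;
}

interface ChatModel {
  invoke: (prompt: string) => Promise<string>;
}

// Before: positional arguments — adding a parameter means editing every node.
// async function handleNode(state: GraphState, model: ChatModel) { ... }

// After: one params interface per kind of node; new members are added in one place.
interface BaseNodeParams {
  state: GraphState;
}

interface NodeParams extends BaseNodeParams {
  model: ChatModel;
}

async function handleNode({ state, model }: NodeParams): Promise<Partial<GraphState>> {
  // A node receives everything it needs in one object and returns a partial state update.
  await model.invoke('...');
  return { lastExecutedChain: 'handleNode' };
}

// Call sites pass a single object, e.g. `await handleNode({ state, model })`, so
// threading a new field (say, a connector-specific option) through the graph does
// not require changing the function signatures again.
```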
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(categorizationMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;
 
 describe('Testing categorization handler', () => {
   it('handleCategorization()', async () => {
-    const response = await handleCategorization(testState, mockLlm);
+    const response = await handleCategorization({ state, model });
     expect(response.currentPipeline).toStrictEqual(
       categorizationExpectedHandlerResponse.currentPipeline
     );
@@ -4,21 +4,18 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
+import type { CategorizationNodeParams } from './types';
 import { combineProcessors } from '../../util/processors';
 import { CATEGORIZATION_MAIN_PROMPT } from './prompts';
 import { CATEGORIZATION_EXAMPLE_PROCESSORS } from './constants';
 
-export async function handleCategorization(
-  state: CategorizationState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleCategorization({
+  state,
+  model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
   const categorizationMainPrompt = CATEGORIZATION_MAIN_PROMPT;
   const outputParser = new JsonOutputParser();
   const categorizationMainGraph = categorizationMainPrompt.pipe(model).pipe(outputParser);
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(categorizationMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;
 
 describe('Testing categorization handler', () => {
   it('handleErrors()', async () => {
-    const response = await handleErrors(testState, mockLlm);
+    const response = await handleErrors({ state, model });
     expect(response.currentPipeline).toStrictEqual(
       categorizationExpectedHandlerResponse.currentPipeline
     );
@@ -4,20 +4,18 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
 import { combineProcessors } from '../../util/processors';
 import { CATEGORIZATION_ERROR_PROMPT } from './prompts';
 
-export async function handleErrors(
-  state: CategorizationState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleErrors({
+  state,
+  model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
   const categorizationErrorPrompt = CATEGORIZATION_ERROR_PROMPT;
 
   const outputParser = new JsonOutputParser();
@@ -31,7 +31,7 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: "I'll callback later.",
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
@@ -45,7 +45,7 @@ jest.mock('../../util/pipeline', () => ({
 }));
 
 describe('runCategorizationGraph', () => {
-  const mockClient = {
+  const client = {
     asCurrentUser: {
       ingest: {
         simulate: jest.fn(),
@@ -131,14 +131,14 @@ describe('runCategorizationGraph', () => {
 
   it('Ensures that the graph compiles', async () => {
     try {
-      await getCategorizationGraph(mockClient, mockLlm);
+      await getCategorizationGraph({ client, model });
     } catch (error) {
-      // noop
+      throw Error(`getCategorizationGraph threw an error: ${error}`);
     }
   });
 
   it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
-    const categorizationGraph = await getCategorizationGraph(mockClient, mockLlm);
+    const categorizationGraph = await getCategorizationGraph({ client, model });
 
     (testPipeline as jest.Mock)
       .mockResolvedValueOnce(testPipelineValidResult)
@@ -151,8 +151,8 @@ describe('runCategorizationGraph', () => {
     let response;
     try {
       response = await categorizationGraph.invoke(mockedRequestWithPipeline);
-    } catch (e) {
-      // noop
+    } catch (error) {
+      throw Error(`getCategorizationGraph threw an error: ${error}`);
     }
 
     expect(response.results).toStrictEqual(categorizationExpectedResults);
@@ -5,14 +5,10 @@
  * 2.0.
  */
-
 import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
 import type { StateGraphArgs } from '@langchain/langgraph';
 import { StateGraph, END, START } from '@langchain/langgraph';
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
 import type { CategorizationState } from '../../types';
+import type { CategorizationGraphParams, CategorizationBaseNodeParams } from './types';
 import { prefixSamples, formatSamples } from '../../util/samples';
 import { handleCategorization } from './categorization';
 import { handleValidatePipeline } from '../../util/graph';
@@ -105,7 +101,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
   },
 };
 
-function modelInput(state: CategorizationState): Partial<CategorizationState> {
+function modelInput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
   const samples = prefixSamples(state);
   const formattedSamples = formatSamples(samples);
   const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
@@ -122,7 +118,7 @@ function modelInput(state: CategorizationState): Partial<CategorizationState> {
   };
 }
 
-function modelOutput(state: CategorizationState): Partial<CategorizationState> {
+function modelOutput({ state }: CategorizationBaseNodeParams): Partial<CategorizationState> {
   return {
     finalized: true,
     lastExecutedChain: 'modelOutput',
@@ -133,14 +129,14 @@ function modelOutput(state: CategorizationState): Partial<CategorizationState> {
   };
 }
 
-function validationRouter(state: CategorizationState): string {
+function validationRouter({ state }: CategorizationBaseNodeParams): string {
   if (Object.keys(state.currentProcessors).length === 0) {
     return 'categorization';
   }
   return 'validateCategorization';
 }
 
-function chainRouter(state: CategorizationState): string {
+function chainRouter({ state }: CategorizationBaseNodeParams): string {
   if (Object.keys(state.errors).length > 0) {
     return 'errors';
   }
@@ -157,27 +153,26 @@ function chainRouter(state: CategorizationState): string {
   return END;
 }
 
-export async function getCategorizationGraph(
-  client: IScopedClusterClient,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function getCategorizationGraph({ client, model }: CategorizationGraphParams) {
   const workflow = new StateGraph({
     channels: graphState,
   })
-    .addNode('modelInput', modelInput)
-    .addNode('modelOutput', modelOutput)
+    .addNode('modelInput', (state: CategorizationState) => modelInput({ state }))
+    .addNode('modelOutput', (state: CategorizationState) => modelOutput({ state }))
     .addNode('handleCategorization', (state: CategorizationState) =>
-      handleCategorization(state, model)
+      handleCategorization({ state, model })
     )
     .addNode('handleValidatePipeline', (state: CategorizationState) =>
-      handleValidatePipeline(state, client)
+      handleValidatePipeline({ state, client })
     )
-    .addNode('handleCategorizationValidation', handleCategorizationValidation)
+    .addNode('handleCategorizationValidation', (state: CategorizationState) =>
+      handleCategorizationValidation({ state })
+    )
    .addNode('handleInvalidCategorization', (state: CategorizationState) =>
-      handleInvalidCategorization(state, model)
+      handleInvalidCategorization({ state, model })
     )
-    .addNode('handleErrors', (state: CategorizationState) => handleErrors(state, model))
-    .addNode('handleReview', (state: CategorizationState) => handleReview(state, model))
+    .addNode('handleErrors', (state: CategorizationState) => handleErrors({ state, model }))
+    .addNode('handleReview', (state: CategorizationState) => handleReview({ state, model }))
    .addEdge(START, 'modelInput')
    .addEdge('modelOutput', END)
    .addEdge('modelInput', 'handleValidatePipeline')
@@ -185,16 +180,24 @@ export async function getCategorizationGraph(
     .addEdge('handleInvalidCategorization', 'handleValidatePipeline')
     .addEdge('handleErrors', 'handleValidatePipeline')
     .addEdge('handleReview', 'handleValidatePipeline')
-    .addConditionalEdges('handleValidatePipeline', validationRouter, {
-      categorization: 'handleCategorization',
-      validateCategorization: 'handleCategorizationValidation',
-    })
-    .addConditionalEdges('handleCategorizationValidation', chainRouter, {
-      modelOutput: 'modelOutput',
-      errors: 'handleErrors',
-      invalidCategorization: 'handleInvalidCategorization',
-      review: 'handleReview',
-    });
+    .addConditionalEdges(
+      'handleValidatePipeline',
+      (state: CategorizationState) => validationRouter({ state }),
+      {
+        categorization: 'handleCategorization',
+        validateCategorization: 'handleCategorizationValidation',
+      }
+    )
+    .addConditionalEdges(
+      'handleCategorizationValidation',
+      (state: CategorizationState) => chainRouter({ state }),
+      {
+        modelOutput: 'modelOutput',
+        errors: 'handleErrors',
+        invalidCategorization: 'handleInvalidCategorization',
+        review: 'handleReview',
+      }
+    );
 
   const compiledCategorizationGraph = workflow.compile();
   return compiledCategorizationGraph;
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(categorizationMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;
 
 describe('Testing categorization handler', () => {
   it('handleInvalidCategorization()', async () => {
-    const response = await handleInvalidCategorization(testState, mockLlm);
+    const response = await handleInvalidCategorization({ state, model });
     expect(response.currentPipeline).toStrictEqual(
       categorizationExpectedHandlerResponse.currentPipeline
     );
@@ -4,21 +4,19 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
 import { combineProcessors } from '../../util/processors';
 import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
 import { CATEGORIZATION_VALIDATION_PROMPT } from './prompts';
 
-export async function handleInvalidCategorization(
-  state: CategorizationState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleInvalidCategorization({
+  state,
+  model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
   const categorizationInvalidPrompt = CATEGORIZATION_VALIDATION_PROMPT;
 
   const outputParser = new JsonOutputParser();
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(categorizationMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;
 
 describe('Testing categorization handler', () => {
   it('handleReview()', async () => {
-    const response = await handleReview(testState, mockLlm);
+    const response = await handleReview({ state, model });
     expect(response.currentPipeline).toStrictEqual(
       categorizationExpectedHandlerResponse.currentPipeline
     );
@@ -4,21 +4,19 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import { CATEGORIZATION_REVIEW_PROMPT } from './prompts';
 import type { Pipeline } from '../../../common';
-import type { CategorizationState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { CategorizationNodeParams } from './types';
+import type { SimplifiedProcessors, SimplifiedProcessor, CategorizationState } from '../../types';
 import { combineProcessors } from '../../util/processors';
 import { ECS_EVENT_TYPES_PER_CATEGORY } from './constants';
 
-export async function handleReview(
-  state: CategorizationState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleReview({
+  state,
+  model,
+}: CategorizationNodeParams): Promise<Partial<CategorizationState>> {
   const categorizationReviewPrompt = CATEGORIZATION_REVIEW_PROMPT;
   const outputParser = new JsonOutputParser();
   const categorizationReview = categorizationReviewPrompt.pipe(model).pipe(outputParser);
@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
+import type { CategorizationState, ChatModels } from '../../types';
+
+export interface CategorizationBaseNodeParams {
+  state: CategorizationState;
+}
+
+export interface CategorizationNodeParams extends CategorizationBaseNodeParams {
+  model: ChatModels;
+}
+
+export interface CategorizationGraphParams {
+  model: ChatModels;
+  client: IScopedClusterClient;
+}
@@ -9,12 +9,12 @@ import { handleCategorizationValidation } from './validate';
 import type { CategorizationState } from '../../types';
 import { categorizationTestState } from '../../../__jest__/fixtures/categorization';
 
-const testState: CategorizationState = categorizationTestState;
+const state: CategorizationState = categorizationTestState;
 
 describe('Testing categorization invalid category', () => {
   it('handleCategorizationValidation()', async () => {
-    testState.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }];
-    const response = handleCategorizationValidation(testState);
+    state.pipelineResults = [{ test: 'testresult', event: { category: ['foo'] } }];
+    const response = handleCategorizationValidation({ state });
     expect(response.invalidCategorization).toEqual([
       {
         error:
@@ -27,8 +27,8 @@ describe('Testing categorization invalid category', () => {
 
 describe('Testing categorization invalid type', () => {
   it('handleCategorizationValidation()', async () => {
-    testState.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }];
-    const response = handleCategorizationValidation(testState);
+    state.pipelineResults = [{ test: 'testresult', event: { type: ['foo'] } }];
+    const response = handleCategorizationValidation({ state });
     expect(response.invalidCategorization).toEqual([
       {
         error:
@@ -41,10 +41,10 @@ describe('Testing categorization invalid type', () => {
 
 describe('Testing categorization invalid compatibility', () => {
   it('handleCategorizationValidation()', async () => {
-    testState.pipelineResults = [
+    state.pipelineResults = [
       { test: 'testresult', event: { category: ['authentication'], type: ['access'] } },
     ];
-    const response = handleCategorizationValidation(testState);
+    const response = handleCategorizationValidation({ state });
     expect(response.invalidCategorization).toEqual([
       {
         error: 'event.type (access) not compatible with any of the event.category (authentication)',
@@ -5,6 +5,7 @@
  * 2.0.
  */
 import type { CategorizationState } from '../../types';
+import type { CategorizationBaseNodeParams } from './types';
 import { ECS_EVENT_TYPES_PER_CATEGORY, EVENT_CATEGORIES, EVENT_TYPES } from './constants';
 
 import type { EventCategories } from './constants';
@@ -22,11 +23,9 @@ interface CategorizationError {
   error: string;
 }
 
-export function handleCategorizationValidation(state: CategorizationState): {
-  previousInvalidCategorization: string;
-  invalidCategorization: CategorizationError[];
-  lastExecutedChain: string;
-} {
+export function handleCategorizationValidation({
+  state,
+}: CategorizationBaseNodeParams): Partial<CategorizationState> {
   let previousInvalidCategorization = '';
   const errors: CategorizationError[] = [];
   const pipelineResults = state.pipelineResults as PipelineResult[];
@@ -14,15 +14,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: '{ "message": "ll callback later."}',
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: EcsMappingState = ecsTestState;
+const state: EcsMappingState = ecsTestState;
 
 describe('Testing ecs handler', () => {
   it('handleDuplicates()', async () => {
-    const response = await handleDuplicates(testState, mockLlm);
+    const response = await handleDuplicates({ state, model });
     expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
     expect(response.lastExecutedChain).toBe('duplicateFields');
   });
@@ -4,18 +4,16 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
+import type { EcsNodeParams } from './types';
 import type { EcsMappingState } from '../../types';
 import { ECS_DUPLICATES_PROMPT } from './prompts';
 
-export async function handleDuplicates(
-  state: EcsMappingState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleDuplicates({
+  state,
+  model,
+}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
   const outputParser = new JsonOutputParser();
   const ecsDuplicatesGraph = ECS_DUPLICATES_PROMPT.pipe(model).pipe(outputParser);
 
@@ -24,7 +24,7 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: "I'll callback later.",
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
@@ -69,16 +69,22 @@ describe('EcsGraph', () => {
     // When getEcsGraph runs, langgraph compiles the graph it will error if the graph has any issues.
     // Common issues for example detecting a node has no next step, or there is a infinite loop between them.
     try {
-      await getEcsGraph(mockLlm);
+      await getEcsGraph({ model });
     } catch (error) {
-      fail(`getEcsGraph threw an error: ${error}`);
+      throw Error(`getEcsGraph threw an error: ${error}`);
     }
   });
   it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
     // The mocked outputs are specifically crafted to trigger ALL different conditions, allowing us to test the whole graph.
     // This is why we have all the expects ensuring each function was called.
-    const ecsGraph = await getEcsGraph(mockLlm);
-    const response = await ecsGraph.invoke(mockedRequest);
+    const ecsGraph = await getEcsGraph({ model });
+    let response;
+    try {
+      response = await ecsGraph.invoke(mockedRequest);
+    } catch (error) {
+      throw Error(`getEcsGraph threw an error: ${error}`);
+    }
+
     expect(response.results).toStrictEqual(ecsMappingExpectedResults);
 
     // Check if the functions were called
@@ -5,12 +5,9 @@
  * 2.0.
  */
 
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
 import { END, START, StateGraph, Send } from '@langchain/langgraph';
 import type { EcsMappingState } from '../../types';
+import type { EcsGraphParams, EcsBaseNodeParams } from './types';
 import { modelInput, modelOutput, modelSubOutput } from './model';
 import { handleDuplicates } from './duplicates';
 import { handleInvalidEcs } from './invalid';
@@ -19,7 +16,7 @@ import { handleMissingKeys } from './missing';
 import { handleValidateMappings } from './validate';
 import { graphState } from './state';
 
-const handleCreateMappingChunks = async (state: EcsMappingState) => {
+const handleCreateMappingChunks = async ({ state }: EcsBaseNodeParams) => {
   // Cherrypick a shallow copy of state to pass to subgraph
   const stateParams = {
     exAnswer: state.exAnswer,
@@ -36,7 +33,7 @@ const handleCreateMappingChunks = async (state: EcsMappingState) => {
   return 'modelOutput';
 };
 
-function chainRouter(state: EcsMappingState): string {
+function chainRouter({ state }: EcsBaseNodeParams): string {
   if (Object.keys(state.duplicateFields).length > 0) {
     return 'duplicateFields';
   }
@@ -53,22 +50,22 @@ function chainRouter(state: EcsMappingState): string {
 }
 
 // This is added as a separate graph to be able to run these steps concurrently from handleCreateMappingChunks
-async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
+async function getEcsSubGraph({ model }: EcsGraphParams) {
   const workflow = new StateGraph({
     channels: graphState,
   })
-    .addNode('modelSubOutput', modelSubOutput)
-    .addNode('handleValidation', handleValidateMappings)
-    .addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping(state, model))
-    .addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates(state, model))
-    .addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys(state, model))
-    .addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs(state, model))
+    .addNode('modelSubOutput', (state: EcsMappingState) => modelSubOutput({ state }))
+    .addNode('handleValidation', (state: EcsMappingState) => handleValidateMappings({ state }))
+    .addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping({ state, model }))
+    .addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates({ state, model }))
+    .addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys({ state, model }))
+    .addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs({ state, model }))
     .addEdge(START, 'handleEcsMapping')
     .addEdge('handleEcsMapping', 'handleValidation')
     .addEdge('handleDuplicates', 'handleValidation')
     .addEdge('handleMissingKeys', 'handleValidation')
     .addEdge('handleInvalidEcs', 'handleValidation')
-    .addConditionalEdges('handleValidation', chainRouter, {
+    .addConditionalEdges('handleValidation', (state: EcsMappingState) => chainRouter({ state }), {
       duplicateFields: 'handleDuplicates',
       missingKeys: 'handleMissingKeys',
      invalidEcsFields: 'handleInvalidEcs',
@@ -81,17 +78,19 @@ async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimp
   return compiledEcsSubGraph;
 }
 
-export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
-  const subGraph = await getEcsSubGraph(model);
+export async function getEcsGraph({ model }: EcsGraphParams) {
+  const subGraph = await getEcsSubGraph({ model });
   const workflow = new StateGraph({
     channels: graphState,
   })
-    .addNode('modelInput', modelInput)
-    .addNode('modelOutput', modelOutput)
+    .addNode('modelInput', (state: EcsMappingState) => modelInput({ state }))
+    .addNode('modelOutput', (state: EcsMappingState) => modelOutput({ state }))
     .addNode('subGraph', subGraph)
     .addEdge(START, 'modelInput')
     .addEdge('subGraph', 'modelOutput')
-    .addConditionalEdges('modelInput', handleCreateMappingChunks)
+    .addConditionalEdges('modelInput', (state: EcsMappingState) =>
+      handleCreateMappingChunks({ state })
+    )
     .addEdge('modelOutput', END);
 
   const compiledEcsGraph = workflow.compile();
@@ -4,4 +4,5 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
+
 export { getEcsGraph } from './graph';
@@ -14,15 +14,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: '{ "message": "ll callback later."}',
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: EcsMappingState = ecsTestState;
+const state: EcsMappingState = ecsTestState;
 
 describe('Testing ecs handlers', () => {
   it('handleInvalidEcs()', async () => {
-    const response = await handleInvalidEcs(testState, mockLlm);
+    const response = await handleInvalidEcs({ state, model });
     expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
     expect(response.lastExecutedChain).toBe('invalidEcs');
   });
@@ -4,18 +4,15 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
 import { JsonOutputParser } from '@langchain/core/output_parsers';
+import type { EcsNodeParams } from './types';
 import type { EcsMappingState } from '../../types';
 import { ECS_INVALID_PROMPT } from './prompts';
 
-export async function handleInvalidEcs(
-  state: EcsMappingState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleInvalidEcs({
+  state,
+  model,
+}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
   const outputParser = new JsonOutputParser();
   const ecsInvalidEcsGraph = ECS_INVALID_PROMPT.pipe(model).pipe(outputParser);
 
@@ -14,15 +14,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: '{ "message": "ll callback later."}',
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: EcsMappingState = ecsTestState;
+const state: EcsMappingState = ecsTestState;
 
 describe('Testing ecs handler', () => {
   it('handleEcsMapping()', async () => {
-    const response = await handleEcsMapping(testState, mockLlm);
+    const response = await handleEcsMapping({ state, model });
     expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
     expect(response.lastExecutedChain).toBe('ecsMapping');
   });
@@ -4,18 +4,16 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
+import type { EcsNodeParams } from './types';
 import type { EcsMappingState } from '../../types';
 import { ECS_MAIN_PROMPT } from './prompts';
 
-export async function handleEcsMapping(
-  state: EcsMappingState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleEcsMapping({
+  state,
+  model,
+}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
   const outputParser = new JsonOutputParser();
   const ecsMainGraph = ECS_MAIN_PROMPT.pipe(model).pipe(outputParser);
 
@@ -14,15 +14,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: '{ "message": "ll callback later."}',
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: EcsMappingState = ecsTestState;
+const state: EcsMappingState = ecsTestState;
 
 describe('Testing ecs handler', () => {
   it('handleMissingKeys()', async () => {
-    const response = await handleMissingKeys(testState, mockLlm);
+    const response = await handleMissingKeys({ state, model });
     expect(response.currentMapping).toStrictEqual({ message: 'll callback later.' });
     expect(response.lastExecutedChain).toBe('missingKeys');
   });
@@ -4,18 +4,16 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
-import type { EcsMappingState } from '../../types';
+import { EcsNodeParams } from './types';
+import { EcsMappingState } from '../../types';
 import { ECS_MISSING_KEYS_PROMPT } from './prompts';
 
-export async function handleMissingKeys(
-  state: EcsMappingState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleMissingKeys({
+  state,
+  model,
+}: EcsNodeParams): Promise<Partial<EcsMappingState>> {
   const outputParser = new JsonOutputParser();
   const ecsMissingGraph = ECS_MISSING_KEYS_PROMPT.pipe(model).pipe(outputParser);
 
@@ -9,15 +9,16 @@ import { ECS_EXAMPLE_ANSWER, ECS_FIELDS } from './constants';
 import { createPipeline } from './pipeline';
 import { mergeAndChunkSamples } from './chunk';
 import type { EcsMappingState } from '../../types';
+import type { EcsBaseNodeParams } from './types';
 
-export function modelSubOutput(state: EcsMappingState): Partial<EcsMappingState> {
+export function modelSubOutput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
   return {
     lastExecutedChain: 'ModelSubOutput',
     finalMapping: state.currentMapping,
   };
 }
 
-export function modelInput(state: EcsMappingState): Partial<EcsMappingState> {
+export function modelInput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
   const prefixedSamples = prefixSamples(state);
   const sampleChunks = mergeAndChunkSamples(prefixedSamples, state.chunkSize);
   return {
@@ -30,7 +31,7 @@ export function modelInput(state: EcsMappingState): Partial<EcsMappingState> {
   };
 }
 
-export function modelOutput(state: EcsMappingState): Partial<EcsMappingState> {
+export function modelOutput({ state }: EcsBaseNodeParams): Partial<EcsMappingState> {
   const currentPipeline = createPipeline(state);
   return {
     finalized: true,
@@ -0,0 +1,19 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+import type { EcsMappingState, ChatModels } from '../../types';
+
+export interface EcsBaseNodeParams {
+  state: EcsMappingState;
+}
+
+export interface EcsNodeParams extends EcsBaseNodeParams {
+  model: ChatModels;
+}
+
+export interface EcsGraphParams {
+  model: ChatModels;
+}
@@ -6,7 +6,7 @@
  */
 /* eslint-disable @typescript-eslint/no-explicit-any */
 import { ECS_FULL } from '../../../common/ecs';
-import type { EcsMappingState } from '../../types';
+import type { EcsBaseNodeParams } from './types';
 import { ECS_RESERVED } from './constants';
 
 const valueFieldKeys = new Set(['target', 'confidence', 'date_formats', 'type']);
@@ -152,7 +152,7 @@ export function findInvalidEcsFields(currentMapping: AnyObject): string[] {
   return results;
 }
 
-export function handleValidateMappings(state: EcsMappingState): AnyObject {
+export function handleValidateMappings({ state }: EcsBaseNodeParams): AnyObject {
   const missingKeys = findMissingFields(state?.combinedSamples, state?.currentMapping);
   const duplicateFields = findDuplicateFields(state?.prefixedSamples, state?.currentMapping);
   const invalidEcsFields = findInvalidEcsFields(state?.currentMapping);
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(relatedMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: RelatedState = relatedTestState;
+const state: RelatedState = relatedTestState;
 
 describe('Testing related handler', () => {
   it('handleErrors()', async () => {
-    const response = await handleErrors(testState, mockLlm);
+    const response = await handleErrors({ state, model });
     expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
     expect(response.lastExecutedChain).toBe('error');
   });
@@ -4,21 +4,19 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
 import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { RelatedNodeParams } from './types';
 import { combineProcessors } from '../../util/processors';
 import { RELATED_ERROR_PROMPT } from './prompts';
 import { COMMON_ERRORS } from './constants';
 
-export async function handleErrors(
-  state: RelatedState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleErrors({
+  state,
+  model,
+}: RelatedNodeParams): Promise<Partial<RelatedState>> {
   const relatedErrorPrompt = RELATED_ERROR_PROMPT;
   const outputParser = new JsonOutputParser();
   const relatedErrorGraph = relatedErrorPrompt.pipe(model).pipe(outputParser);
@@ -28,7 +28,7 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: "I'll callback later.",
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
@@ -41,7 +41,7 @@ jest.mock('../../util/pipeline', () => ({
 }));
 
 describe('runRelatedGraph', () => {
-  const mockClient = {
+  const client = {
     asCurrentUser: {
       indices: {
         getMapping: jest.fn(),
@@ -106,14 +106,14 @@ describe('runRelatedGraph', () => {
 
   it('Ensures that the graph compiles', async () => {
     try {
-      await getRelatedGraph(mockClient, mockLlm);
+      await getRelatedGraph({ client, model });
     } catch (error) {
-      // noop
+      throw Error(`getRelatedGraph threw an error: ${error}`);
     }
   });
 
   it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
-    const relatedGraph = await getRelatedGraph(mockClient, mockLlm);
+    const relatedGraph = await getRelatedGraph({ client, model });
 
     (testPipeline as jest.Mock)
       .mockResolvedValueOnce(testPipelineValidResult)
@@ -125,8 +125,8 @@ describe('runRelatedGraph', () => {
     let response;
     try {
       response = await relatedGraph.invoke(mockedRequestWithPipeline);
-    } catch (e) {
-      // noop
+    } catch (error) {
+      throw Error(`getRelatedGraph threw an error: ${error}`);
     }
 
     expect(response.results).toStrictEqual(relatedExpectedResults);
@@ -5,14 +5,10 @@
  * 2.0.
  */
-
 import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
 import type { StateGraphArgs } from '@langchain/langgraph';
 import { StateGraph, END, START } from '@langchain/langgraph';
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
 import type { RelatedState } from '../../types';
+import type { RelatedGraphParams, RelatedBaseNodeParams } from './types';
 import { prefixSamples, formatSamples } from '../../util/samples';
 import { handleValidatePipeline } from '../../util/graph';
 import { handleRelated } from './related';
@@ -91,7 +87,7 @@ const graphState: StateGraphArgs<RelatedState>['channels'] = {
   },
 };
 
-function modelInput(state: RelatedState): Partial<RelatedState> {
+function modelInput({ state }: RelatedBaseNodeParams): Partial<RelatedState> {
   const samples = prefixSamples(state);
   const formattedSamples = formatSamples(samples);
   const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
@@ -107,7 +103,7 @@ function modelInput(state: RelatedState): Partial<RelatedState> {
   };
 }
 
-function modelOutput(state: RelatedState): Partial<RelatedState> {
+function modelOutput({ state }: RelatedBaseNodeParams): Partial<RelatedState> {
   return {
     finalized: true,
     lastExecutedChain: 'modelOutput',
@@ -118,14 +114,14 @@ function modelOutput(state: RelatedState): Partial<RelatedState> {
   };
 }
 
-function inputRouter(state: RelatedState): string {
+function inputRouter({ state }: RelatedBaseNodeParams): string {
   if (Object.keys(state.pipelineResults).length === 0) {
     return 'validatePipeline';
   }
   return 'related';
 }
 
-function chainRouter(state: RelatedState): string {
+function chainRouter({ state }: RelatedBaseNodeParams): string {
   if (Object.keys(state.currentProcessors).length === 0) {
     return 'related';
   }
@@ -141,34 +137,35 @@ function chainRouter(state: RelatedState): string {
   return END;
 }
 
-export async function getRelatedGraph(
-  client: IScopedClusterClient,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function getRelatedGraph({ client, model }: RelatedGraphParams) {
   const workflow = new StateGraph({ channels: graphState })
-    .addNode('modelInput', modelInput)
-    .addNode('modelOutput', modelOutput)
-    .addNode('handleRelated', (state: RelatedState) => handleRelated(state, model))
+    .addNode('modelInput', (state: RelatedState) => modelInput({ state }))
+    .addNode('modelOutput', (state: RelatedState) => modelOutput({ state }))
+    .addNode('handleRelated', (state: RelatedState) => handleRelated({ state, model }))
     .addNode('handleValidatePipeline', (state: RelatedState) =>
-      handleValidatePipeline(state, client)
+      handleValidatePipeline({ state, client })
     )
-    .addNode('handleErrors', (state: RelatedState) => handleErrors(state, model))
-    .addNode('handleReview', (state: RelatedState) => handleReview(state, model))
+    .addNode('handleErrors', (state: RelatedState) => handleErrors({ state, model }))
+    .addNode('handleReview', (state: RelatedState) => handleReview({ state, model }))
     .addEdge(START, 'modelInput')
     .addEdge('modelOutput', END)
     .addEdge('handleRelated', 'handleValidatePipeline')
     .addEdge('handleErrors', 'handleValidatePipeline')
     .addEdge('handleReview', 'handleValidatePipeline')
-    .addConditionalEdges('modelInput', inputRouter, {
+    .addConditionalEdges('modelInput', (state: RelatedState) => inputRouter({ state }), {
      related: 'handleRelated',
      validatePipeline: 'handleValidatePipeline',
    })
-    .addConditionalEdges('handleValidatePipeline', chainRouter, {
-      related: 'handleRelated',
-      errors: 'handleErrors',
-      review: 'handleReview',
-      modelOutput: 'modelOutput',
-    });
+    .addConditionalEdges(
+      'handleValidatePipeline',
+      (state: RelatedState) => chainRouter({ state }),
+      {
+        related: 'handleRelated',
+        errors: 'handleErrors',
+        review: 'handleReview',
+        modelOutput: 'modelOutput',
+      }
+    );
 
   const compiledRelatedGraph = workflow.compile();
   return compiledRelatedGraph;
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(relatedMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: RelatedState = relatedTestState;
+const state: RelatedState = relatedTestState;
 
 describe('Testing related handler', () => {
   it('handleRelated()', async () => {
-    const response = await handleRelated(testState, mockLlm);
+    const response = await handleRelated({ state, model });
     expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
     expect(response.lastExecutedChain).toBe('related');
   });
@@ -4,20 +4,18 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
 import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { RelatedNodeParams } from './types';
 import { combineProcessors } from '../../util/processors';
 import { RELATED_MAIN_PROMPT } from './prompts';
 
-export async function handleRelated(
-  state: RelatedState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleRelated({
+  state,
+  model,
+}: RelatedNodeParams): Promise<Partial<RelatedState>> {
   const relatedMainPrompt = RELATED_MAIN_PROMPT;
   const outputParser = new JsonOutputParser();
   const relatedMainGraph = relatedMainPrompt.pipe(model).pipe(outputParser);
@@ -18,15 +18,15 @@ import {
   ActionsClientSimpleChatModel,
 } from '@kbn/langchain/server/language_models';
 
-const mockLlm = new FakeLLM({
+const model = new FakeLLM({
   response: JSON.stringify(relatedMockProcessors, null, 2),
 }) as unknown as ActionsClientChatOpenAI | ActionsClientSimpleChatModel;
 
-const testState: RelatedState = relatedTestState;
+const state: RelatedState = relatedTestState;
 
 describe('Testing related handler', () => {
   it('handleReview()', async () => {
-    const response = await handleReview(testState, mockLlm);
+    const response = await handleReview({ state, model });
     expect(response.currentPipeline).toStrictEqual(relatedExpectedHandlerResponse.currentPipeline);
     expect(response.lastExecutedChain).toBe('review');
   });
@@ -4,20 +4,18 @@
  * 2.0; you may not use this file except in compliance with the Elastic License
  * 2.0.
  */
-import type {
-  ActionsClientChatOpenAI,
-  ActionsClientSimpleChatModel,
-} from '@kbn/langchain/server/language_models';
+
 import { JsonOutputParser } from '@langchain/core/output_parsers';
 import type { Pipeline } from '../../../common';
 import type { RelatedState, SimplifiedProcessors, SimplifiedProcessor } from '../../types';
+import type { RelatedNodeParams } from './types';
 import { combineProcessors } from '../../util/processors';
 import { RELATED_REVIEW_PROMPT } from './prompts';
 
-export async function handleReview(
-  state: RelatedState,
-  model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
-) {
+export async function handleReview({
+  state,
+  model,
+}: RelatedNodeParams): Promise<Partial<RelatedState>> {
   const relatedReviewPrompt = RELATED_REVIEW_PROMPT;
   const outputParser = new JsonOutputParser();
   const relatedReviewGraph = relatedReviewPrompt.pipe(model).pipe(outputParser);
@@ -0,0 +1,22 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
+import { RelatedState, ChatModels } from '../../types';
+
+export interface RelatedBaseNodeParams {
+  state: RelatedState;
+}
+
+export interface RelatedNodeParams extends RelatedBaseNodeParams {
+  model: ChatModels;
+}
+
+export interface RelatedGraphParams {
+  client: IScopedClusterClient;
+  model: ChatModels;
+}
@@ -92,7 +92,7 @@ export function registerCategorizationRoutes(
         ],
       };
 
-      const graph = await getCategorizationGraph(client, model);
+      const graph = await getCategorizationGraph({ client, model });
      const results = await graph.invoke(parameters, options);
 
      return res.ok({ body: CategorizationResponse.parse(results) });
@@ -84,7 +84,7 @@ export function registerEcsRoutes(router: IRouter<IntegrationAssistantRouteHandl
         ],
       };
 
-      const graph = await getEcsGraph(model);
+      const graph = await getEcsGraph({ model });
      const results = await graph.invoke(parameters, options);
 
      return res.ok({ body: EcsMappingResponse.parse(results) });
@@ -84,7 +84,7 @@ export function registerRelatedRoutes(router: IRouter<IntegrationAssistantRouteH
         ],
      };
 
-      const graph = await getRelatedGraph(client, model);
+      const graph = await getRelatedGraph({ client, model });
      const results = await graph.invoke(parameters, options);
      return res.ok({ body: RelatedResponse.parse(results) });
    } catch (e) {
@@ -6,6 +6,12 @@
  */
 
 import type { LicensingPluginSetup, LicensingPluginStart } from '@kbn/licensing-plugin/server';
+import {
+  ActionsClientChatOpenAI,
+  ActionsClientBedrockChatModel,
+  ActionsClientSimpleChatModel,
+  ActionsClientGeminiChatModel,
+} from '@kbn/langchain/server';
 
 export interface IntegrationAssistantPluginSetup {
   setIsAvailable: (isAvailable: boolean) => void;
@@ -97,3 +103,9 @@ export interface RelatedState {
   results: object;
   lastExecutedChain: string;
 }
+
+export type ChatModels =
+  | ActionsClientChatOpenAI
+  | ActionsClientBedrockChatModel
+  | ActionsClientSimpleChatModel
+  | ActionsClientGeminiChatModel;
@@ -8,10 +8,15 @@ import type { IScopedClusterClient } from '@kbn/core-elasticsearch-server';
 import type { CategorizationState, RelatedState } from '../types';
 import { testPipeline } from './pipeline';
 
-export async function handleValidatePipeline(
-  state: CategorizationState | RelatedState,
-  client: IScopedClusterClient
-): Promise<Partial<CategorizationState> | Partial<RelatedState>> {
+interface HandleValidateNodeParams {
+  state: CategorizationState | RelatedState;
+  client: IScopedClusterClient;
+}
+
+export async function handleValidatePipeline({
+  state,
+  client,
+}: HandleValidateNodeParams): Promise<Partial<CategorizationState> | Partial<RelatedState>> {
   const previousError = JSON.stringify(state.errors, null, 2);
   const results = await testPipeline(state.rawSamples, state.currentPipeline, client);
   return {