mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
# Backport This will backport the following commits from `main` to `8.x`: - [Update langchain (main) #205553 ](https://github.com/elastic/kibana/pull/205553) <!--- Backport version: 9.6.6 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport) Co-authored-by: elastic-renovate-prod[bot] <174716857+elastic-renovate-prod[bot]@users.noreply.github.com>
This commit is contained in:
parent
4ee1e8756c
commit
22fc6d8bea
23 changed files with 1004 additions and 1042 deletions
47
package.json
47
package.json
|
@ -79,8 +79,8 @@
|
|||
"resolutions": {
|
||||
"**/@bazel/typescript/protobufjs": "6.11.4",
|
||||
"**/@hello-pangea/dnd": "16.6.0",
|
||||
"**/@langchain/core": "^0.3.16",
|
||||
"**/@langchain/google-common": "^0.1.1",
|
||||
"**/@langchain/core": "^0.3.40",
|
||||
"**/@langchain/google-common": "^0.1.8",
|
||||
"**/@types/node": "20.10.5",
|
||||
"**/@typescript-eslint/utils": "5.62.0",
|
||||
"**/chokidar": "^3.5.3",
|
||||
|
@ -88,10 +88,14 @@
|
|||
"**/globule/minimatch": "^3.1.2",
|
||||
"**/hoist-non-react-statics": "^3.3.2",
|
||||
"**/isomorphic-fetch/node-fetch": "^2.6.7",
|
||||
"**/langchain": "^0.3.5",
|
||||
"**/langchain": "^0.3.15",
|
||||
"**/remark-parse/trim": "1.0.1",
|
||||
"**/sharp": "0.32.6",
|
||||
"**/typescript": "5.1.6",
|
||||
"@aws-sdk/client-bedrock-agent-runtime": "^3.744.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.744.0",
|
||||
"@aws-sdk/client-kendra": "3.744.0",
|
||||
"@aws-sdk/credential-provider-node": "3.744.0",
|
||||
"@storybook/react-docgen-typescript-plugin": "1.0.6--canary.9.cd77847.0",
|
||||
"@types/react": "~18.2.0",
|
||||
"@types/react-dom": "~18.2.0",
|
||||
|
@ -102,7 +106,7 @@
|
|||
"@appland/sql-parser": "^1.5.1",
|
||||
"@aws-crypto/sha256-js": "^5.2.0",
|
||||
"@aws-crypto/util": "^5.2.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.687.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.744.0",
|
||||
"@babel/runtime": "^7.24.7",
|
||||
"@dagrejs/dagre": "^1.1.4",
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
|
@ -141,7 +145,7 @@
|
|||
"@formatjs/intl-relativetimeformat": "^11.2.12",
|
||||
"@formatjs/intl-utils": "^3.8.4",
|
||||
"@formatjs/ts-transformer": "^3.13.14",
|
||||
"@google/generative-ai": "^0.7.0",
|
||||
"@google/generative-ai": "^0.21.0",
|
||||
"@grpc/grpc-js": "^1.8.22",
|
||||
"@hapi/accept": "^6.0.3",
|
||||
"@hapi/boom": "^10.0.1",
|
||||
|
@ -1023,14 +1027,14 @@
|
|||
"@kbn/xstate-utils": "link:src/platform/packages/shared/kbn-xstate-utils",
|
||||
"@kbn/zod": "link:src/platform/packages/shared/kbn-zod",
|
||||
"@kbn/zod-helpers": "link:src/platform/packages/shared/kbn-zod-helpers",
|
||||
"@langchain/aws": "^0.1.2",
|
||||
"@langchain/community": "0.3.14",
|
||||
"@langchain/core": "^0.3.16",
|
||||
"@langchain/google-common": "^0.1.1",
|
||||
"@langchain/google-genai": "^0.1.2",
|
||||
"@langchain/google-vertexai": "^0.1.0",
|
||||
"@langchain/langgraph": "0.2.19",
|
||||
"@langchain/openai": "^0.3.11",
|
||||
"@langchain/aws": "^0.1.3",
|
||||
"@langchain/community": "^0.3.29",
|
||||
"@langchain/core": "^0.3.40",
|
||||
"@langchain/google-common": "^0.1.8",
|
||||
"@langchain/google-genai": "^0.1.8",
|
||||
"@langchain/google-vertexai": "^0.1.8",
|
||||
"@langchain/langgraph": "^0.2.45",
|
||||
"@langchain/openai": "^0.4.4",
|
||||
"@langtrase/trace-attributes": "^7.5.0",
|
||||
"@launchdarkly/node-server-sdk": "^9.7.2",
|
||||
"@launchdarkly/openfeature-node-server": "^1.0.0",
|
||||
|
@ -1057,14 +1061,11 @@
|
|||
"@paralleldrive/cuid2": "^2.2.2",
|
||||
"@reduxjs/toolkit": "1.9.7",
|
||||
"@slack/webhook": "^7.0.1",
|
||||
"@smithy/eventstream-codec": "^3.1.10",
|
||||
"@smithy/eventstream-serde-node": "^3.0.12",
|
||||
"@smithy/middleware-stack": "^3.0.11",
|
||||
"@smithy/node-http-handler": "^3.3.3",
|
||||
"@smithy/protocol-http": "^4.1.7",
|
||||
"@smithy/signature-v4": "^4.2.3",
|
||||
"@smithy/types": "^3.7.1",
|
||||
"@smithy/util-utf8": "^3.0.0",
|
||||
"@smithy/eventstream-codec": "^4.0.1",
|
||||
"@smithy/eventstream-serde-node": "^4.0.1",
|
||||
"@smithy/middleware-stack": "^4.0.1",
|
||||
"@smithy/types": "^4.1.0",
|
||||
"@smithy/util-utf8": "^4.0.0",
|
||||
"@tanstack/react-query": "^4.29.12",
|
||||
"@tanstack/react-query-devtools": "^4.29.12",
|
||||
"@turf/along": "6.0.1",
|
||||
|
@ -1176,8 +1177,8 @@
|
|||
"jsonwebtoken": "^9.0.2",
|
||||
"jsts": "^1.6.2",
|
||||
"kea": "^2.6.0",
|
||||
"langchain": "^0.3.5",
|
||||
"langsmith": "^0.2.5",
|
||||
"langchain": "^0.3.15",
|
||||
"langsmith": "^0.3.7",
|
||||
"launchdarkly-js-client-sdk": "^3.5.0",
|
||||
"load-json-file": "^6.2.0",
|
||||
"lodash": "^4.17.21",
|
||||
|
|
|
@ -703,4 +703,4 @@
|
|||
"datasourceTemplate": "docker"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
|
@ -90,12 +90,6 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
|
|||
streaming,
|
||||
// matters only for the LangSmith logs (Metadata > Invocation Params), which are misleading if this is not set
|
||||
modelName: model ?? DEFAULT_OPEN_AI_MODEL,
|
||||
// these have to be initialized, but are not actually used since we override the openai client with the actions client
|
||||
azureOpenAIApiKey: 'nothing',
|
||||
azureOpenAIApiDeploymentName: 'nothing',
|
||||
azureOpenAIApiInstanceName: 'nothing',
|
||||
azureOpenAIBasePath: 'nothing',
|
||||
azureOpenAIApiVersion: 'nothing',
|
||||
openAIApiKey: '',
|
||||
});
|
||||
this.#actionsClient = actionsClient;
|
||||
|
|
|
@ -140,6 +140,7 @@ const callOptions = {
|
|||
const handleLLMNewToken = jest.fn();
|
||||
const callRunManager = {
|
||||
handleLLMNewToken,
|
||||
handleCustomEvent: jest.fn().mockResolvedValue({}),
|
||||
} as unknown as CallbackManagerForLLMRun;
|
||||
const onFailedAttempt = jest.fn();
|
||||
const defaultArgs = {
|
||||
|
@ -149,6 +150,7 @@ const defaultArgs = {
|
|||
streaming: false,
|
||||
maxRetries: 0,
|
||||
onFailedAttempt,
|
||||
convertSystemMessageToHumanContent: false,
|
||||
};
|
||||
|
||||
const testMessage = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
|
||||
|
@ -188,7 +190,6 @@ describe('ActionsClientChatVertexAI', () => {
|
|||
describe('_generate streaming: false', () => {
|
||||
it('returns the expected content when _generate is invoked', async () => {
|
||||
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
|
||||
|
||||
const result = await actionsClientChatVertexAI._generate(
|
||||
callMessages,
|
||||
callOptions,
|
||||
|
@ -220,7 +221,7 @@ describe('ActionsClientChatVertexAI', () => {
|
|||
expect(onFailedAttempt).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('rejects with the expected error the message has invalid content', async () => {
|
||||
it('resolves to expected result when message has invalid content', async () => {
|
||||
actionsClient.execute.mockImplementation(
|
||||
jest.fn().mockResolvedValue({
|
||||
data: {
|
||||
|
@ -235,7 +236,7 @@ describe('ActionsClientChatVertexAI', () => {
|
|||
|
||||
await expect(
|
||||
actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
|
||||
).rejects.toThrowError("Cannot read properties of undefined (reading 'text')");
|
||||
).resolves.toEqual({ generations: [], llmOutput: {} });
|
||||
});
|
||||
});
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import { Readable } from 'stream';
|
|||
import { Logger } from '@kbn/logging';
|
||||
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
|
||||
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
|
||||
import { GeminiPartText } from '@langchain/google-common/dist/types';
|
||||
import { GeminiPartText, GeminiRequest } from '@langchain/google-common/dist/types';
|
||||
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
|
||||
import {
|
||||
convertResponseBadFinishReasonToErrorMsg,
|
||||
|
@ -36,6 +36,7 @@ export interface CustomChatModelInput extends BaseChatModelParams {
|
|||
model?: string;
|
||||
maxTokens?: number;
|
||||
telemetryMetadata?: TelemetryMetadata;
|
||||
convertSystemMessageToHumanContent?: boolean;
|
||||
}
|
||||
|
||||
export class ActionsClientChatVertexAI extends ChatVertexAI {
|
||||
|
@ -56,7 +57,9 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
|
|||
this.#model = props.model;
|
||||
this.#actionsClient = actionsClient;
|
||||
this.#connectorId = connectorId;
|
||||
const client = this.buildClient(props);
|
||||
// apiKey is required to build the client but is overridden by the actionsClient
|
||||
const client = this.buildClient({ apiKey: 'nothing', ...props });
|
||||
|
||||
this.connection = new ActionsClientChatConnection(
|
||||
{
|
||||
...this,
|
||||
|
@ -80,7 +83,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
|
|||
runManager?: CallbackManagerForLLMRun
|
||||
): AsyncGenerator<ChatGenerationChunk> {
|
||||
const parameters = this.invocationParams(options);
|
||||
const data = await this.connection.formatData(messages, parameters);
|
||||
const data = (await this.connection.formatData(messages, parameters)) as GeminiRequest;
|
||||
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
|
||||
const systemPart: GeminiPartText | undefined = data?.systemInstruction
|
||||
?.parts?.[0] as unknown as GeminiPartText;
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
import {
|
||||
ChatConnection,
|
||||
GeminiContent,
|
||||
GeminiRequest,
|
||||
GoogleAbstractedClient,
|
||||
GoogleAIBaseLLMInput,
|
||||
GoogleLLMResponse,
|
||||
|
@ -46,7 +47,7 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
|
|||
this.temperature = fields.temperature ?? 0;
|
||||
const nativeFormatData = this.formatData.bind(this);
|
||||
this.formatData = async (data, options) => {
|
||||
const result = await nativeFormatData(data, options);
|
||||
const result = (await nativeFormatData(data, options)) as GeminiRequest;
|
||||
if (result?.contents != null && result?.contents.length) {
|
||||
// ensure there are not 2 messages in a row from the same role,
|
||||
// if there are combine them
|
||||
|
|
|
@ -6,8 +6,7 @@
|
|||
*/
|
||||
|
||||
import { BaseCallbackHandlerInput } from '@langchain/core/callbacks/base';
|
||||
import type { Run } from 'langsmith/schemas';
|
||||
import { BaseTracer } from '@langchain/core/tracers/base';
|
||||
import { BaseTracer, Run } from '@langchain/core/tracers/base';
|
||||
import agent from 'elastic-apm-node';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
*/
|
||||
|
||||
import { BaseCallbackHandlerInput } from '@langchain/core/callbacks/base';
|
||||
import type { Run } from 'langsmith/schemas';
|
||||
import { Run } from 'langsmith/schemas';
|
||||
import { BaseTracer } from '@langchain/core/tracers/base';
|
||||
import { AnalyticsServiceSetup, Logger } from '@kbn/core/server';
|
||||
|
||||
|
|
|
@ -117,7 +117,7 @@ describe('geminiAdapter', () => {
|
|||
name: 'myFunction',
|
||||
parameters: {
|
||||
properties: {},
|
||||
type: 'OBJECT',
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -128,11 +128,11 @@ describe('geminiAdapter', () => {
|
|||
foo: {
|
||||
description: 'foo',
|
||||
enum: undefined,
|
||||
type: 'STRING',
|
||||
type: 'string',
|
||||
},
|
||||
},
|
||||
required: ['foo'],
|
||||
type: 'OBJECT',
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
],
|
||||
|
|
|
@ -110,7 +110,7 @@ function toolsToGemini(tools: ToolOptions['tools']): Gemini.Tool[] {
|
|||
parameters: schema
|
||||
? toolSchemaToGemini({ schema })
|
||||
: {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
type: Gemini.SchemaType.OBJECT,
|
||||
properties: {},
|
||||
},
|
||||
};
|
||||
|
@ -130,13 +130,13 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
|
|||
switch (def.type) {
|
||||
case 'array':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.ARRAY,
|
||||
type: Gemini.SchemaType.ARRAY,
|
||||
description: def.description,
|
||||
items: convertSchemaType({ def: def.items }) as Gemini.FunctionDeclarationSchema,
|
||||
};
|
||||
case 'object':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
type: Gemini.SchemaType.OBJECT,
|
||||
description: def.description,
|
||||
required: def.required as string[],
|
||||
properties: def.properties
|
||||
|
@ -152,19 +152,19 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
|
|||
};
|
||||
case 'string':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.STRING,
|
||||
type: Gemini.SchemaType.STRING,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
case 'boolean':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.BOOLEAN,
|
||||
type: Gemini.SchemaType.BOOLEAN,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
case 'number':
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.NUMBER,
|
||||
type: Gemini.SchemaType.NUMBER,
|
||||
description: def.description,
|
||||
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
|
||||
};
|
||||
|
@ -172,7 +172,7 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
|
|||
};
|
||||
|
||||
return {
|
||||
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
|
||||
type: Gemini.SchemaType.OBJECT,
|
||||
required: schema.required as string[],
|
||||
properties: Object.entries(schema.properties ?? {}).reduce<
|
||||
Record<string, Gemini.FunctionDeclarationSchemaProperty>
|
||||
|
|
|
@ -9,7 +9,6 @@ import { ElasticsearchClient, Logger } from '@kbn/core/server';
|
|||
import { Replacements } from '@kbn/elastic-assistant-common';
|
||||
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
|
||||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import type { CompiledStateGraph } from '@langchain/langgraph';
|
||||
import { END, START, StateGraph } from '@langchain/langgraph';
|
||||
|
||||
import { CombinedPrompts } from './nodes/helpers/prompts';
|
||||
|
@ -18,11 +17,10 @@ import { getGenerateOrEndEdge } from './edges/generate_or_end';
|
|||
import { getGenerateOrRefineOrEndEdge } from './edges/generate_or_refine_or_end';
|
||||
import { getRefineOrEndEdge } from './edges/refine_or_end';
|
||||
import { getRetrieveAnonymizedAlertsOrGenerateEdge } from './edges/retrieve_anonymized_alerts_or_generate';
|
||||
import { getDefaultGraphState } from './state';
|
||||
import { getDefaultGraphAnnotation } from './state';
|
||||
import { getGenerateNode } from './nodes/generate';
|
||||
import { getRefineNode } from './nodes/refine';
|
||||
import { getRetrieveAnonymizedAlertsNode } from './nodes/retriever';
|
||||
import type { GraphState } from './types';
|
||||
|
||||
export interface GetDefaultAttackDiscoveryGraphParams {
|
||||
alertsIndexPattern?: string;
|
||||
|
@ -61,13 +59,9 @@ export const getDefaultAttackDiscoveryGraph = ({
|
|||
replacements,
|
||||
size,
|
||||
start,
|
||||
}: GetDefaultAttackDiscoveryGraphParams): CompiledStateGraph<
|
||||
GraphState,
|
||||
Partial<GraphState>,
|
||||
'generate' | 'refine' | 'retrieve_anonymized_alerts' | '__start__'
|
||||
> => {
|
||||
}: GetDefaultAttackDiscoveryGraphParams) => {
|
||||
try {
|
||||
const graphState = getDefaultGraphState({ end, filter, prompts, start });
|
||||
const graphState = getDefaultGraphAnnotation({ end, filter, prompts, start });
|
||||
|
||||
// get nodes:
|
||||
const retrieveAnonymizedAlertsNode = getRetrieveAnonymizedAlertsNode({
|
||||
|
@ -103,7 +97,7 @@ export const getDefaultAttackDiscoveryGraph = ({
|
|||
getRetrieveAnonymizedAlertsOrGenerateEdge(logger);
|
||||
|
||||
// create the graph:
|
||||
const graph = new StateGraph<GraphState>({ channels: graphState })
|
||||
const graph = new StateGraph(graphState)
|
||||
.addNode(NodeType.RETRIEVE_ANONYMIZED_ALERTS_NODE, retrieveAnonymizedAlertsNode)
|
||||
.addNode(NodeType.GENERATE_NODE, generateNode)
|
||||
.addNode(NodeType.REFINE_NODE, refineNode)
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultGraphState } from '.';
|
||||
import { getDefaultGraphAnnotation } from '.';
|
||||
import {
|
||||
DEFAULT_MAX_GENERATION_ATTEMPTS,
|
||||
DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
|
@ -26,118 +26,122 @@ const prompts = {
|
|||
};
|
||||
describe('getDefaultGraphState', () => {
|
||||
it('returns the expected default attackDiscoveries', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.attackDiscoveries?.default?.()).toBeNull();
|
||||
expect(graphAnnotation.spec.attackDiscoveries.value).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the expected default attackDiscoveryPrompt', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.attackDiscoveryPrompt?.default?.()).toEqual(defaultAttackDiscoveryPrompt);
|
||||
expect(graphAnnotation.spec.attackDiscoveryPrompt.value).toEqual(defaultAttackDiscoveryPrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default empty collection of anonymizedAlerts', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.anonymizedAlerts?.default?.()).toHaveLength(0);
|
||||
expect(graphAnnotation.spec.anonymizedAlerts.value).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default combinedGenerations state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.combinedGenerations?.default?.()).toBe('');
|
||||
expect(graphAnnotation.spec.combinedGenerations.value).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default combinedRefinements state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.combinedRefinements?.default?.()).toBe('');
|
||||
expect(graphAnnotation.spec.combinedRefinements.value).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default errors state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.errors?.default?.()).toHaveLength(0);
|
||||
expect(graphAnnotation.spec.errors.value).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('return the expected default generationAttempts state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.generationAttempts?.default?.()).toBe(0);
|
||||
expect(graphAnnotation.spec.generationAttempts.value).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default generations state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.generations?.default?.()).toHaveLength(0);
|
||||
expect(graphAnnotation.spec.generations.value).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default hallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.hallucinationFailures?.default?.()).toBe(0);
|
||||
expect(graphAnnotation.spec.hallucinationFailures.value).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default refinePrompt state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.refinePrompt?.default?.()).toEqual(defaultRefinePrompt);
|
||||
expect(graphAnnotation.spec.refinePrompt.value).toEqual(defaultRefinePrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default maxGenerationAttempts state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.maxGenerationAttempts?.default?.()).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
|
||||
expect(graphAnnotation.spec.maxGenerationAttempts.value).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
|
||||
});
|
||||
|
||||
it('returns the expected default maxHallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
expect(state.maxHallucinationFailures?.default?.()).toBe(DEFAULT_MAX_HALLUCINATION_FAILURES);
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
expect(graphAnnotation.spec.maxHallucinationFailures.value).toBe(
|
||||
DEFAULT_MAX_HALLUCINATION_FAILURES
|
||||
);
|
||||
});
|
||||
|
||||
it('returns the expected default maxRepeatedGenerations state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.maxRepeatedGenerations?.default?.()).toBe(DEFAULT_MAX_REPEATED_GENERATIONS);
|
||||
expect(graphAnnotation.spec.maxRepeatedGenerations.value).toBe(
|
||||
DEFAULT_MAX_REPEATED_GENERATIONS
|
||||
);
|
||||
});
|
||||
|
||||
it('returns the expected default refinements state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.refinements?.default?.()).toHaveLength(0);
|
||||
expect(graphAnnotation.spec.refinements.value).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default replacements state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.replacements?.default?.()).toEqual({});
|
||||
expect(graphAnnotation.spec.replacements.value).toEqual({});
|
||||
});
|
||||
|
||||
it('returns the expected default unrefinedResults state', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.unrefinedResults?.default?.()).toBeNull();
|
||||
expect(graphAnnotation.spec.unrefinedResults.value).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the expected default end', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.end?.default?.()).toBeUndefined();
|
||||
expect(graphAnnotation.spec.end.value).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns the expected end when it is provided', () => {
|
||||
const end = '2025-01-02T00:00:00.000Z';
|
||||
|
||||
const state = getDefaultGraphState({ prompts, end });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts, end });
|
||||
|
||||
expect(state.end?.default?.()).toEqual(end);
|
||||
expect(graphAnnotation.spec.end.value).toEqual(end);
|
||||
});
|
||||
|
||||
it('returns the expected default filter to be undefined', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.filter?.default?.()).toBeUndefined();
|
||||
expect(graphAnnotation.spec.filter.value).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns the expected filter when it is provided', () => {
|
||||
|
@ -162,22 +166,22 @@ describe('getDefaultGraphState', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const state = getDefaultGraphState({ prompts, filter });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts, filter });
|
||||
|
||||
expect(state.filter?.default?.()).toEqual(filter);
|
||||
expect(graphAnnotation.spec.filter.value).toEqual(filter);
|
||||
});
|
||||
|
||||
it('returns the expected default start to be undefined', () => {
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
|
||||
|
||||
expect(state.start?.default?.()).toBeUndefined();
|
||||
expect(graphAnnotation.spec.start.value).toBeUndefined();
|
||||
});
|
||||
|
||||
it('returns the expected start when it is provided', () => {
|
||||
const start = '2025-01-01T00:00:00.000Z';
|
||||
|
||||
const state = getDefaultGraphState({ prompts, start });
|
||||
const graphAnnotation = getDefaultGraphAnnotation({ prompts, start });
|
||||
|
||||
expect(state.start?.default?.()).toEqual(start);
|
||||
expect(graphAnnotation.spec.start.value).toEqual(start);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import { AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
|
||||
import type { Document } from '@langchain/core/documents';
|
||||
import type { StateGraphArgs } from '@langchain/langgraph';
|
||||
import { Annotation } from '@langchain/langgraph';
|
||||
|
||||
import { AttackDiscoveryPrompts } from '../nodes/helpers/prompts';
|
||||
import {
|
||||
|
@ -15,7 +15,6 @@ import {
|
|||
DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
} from '../constants';
|
||||
import type { GraphState } from '../types';
|
||||
|
||||
export interface Options {
|
||||
end?: string;
|
||||
|
@ -24,90 +23,86 @@ export interface Options {
|
|||
start?: string;
|
||||
}
|
||||
|
||||
export const getDefaultGraphState = ({
|
||||
end,
|
||||
filter,
|
||||
prompts,
|
||||
start,
|
||||
}: Options): StateGraphArgs<GraphState>['channels'] => ({
|
||||
attackDiscoveries: {
|
||||
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
|
||||
default: () => null,
|
||||
},
|
||||
attackDiscoveryPrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.default,
|
||||
},
|
||||
anonymizedAlerts: {
|
||||
value: (x: Document[], y?: Document[]) => y ?? x,
|
||||
default: () => [],
|
||||
},
|
||||
combinedGenerations: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
combinedRefinements: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
continuePrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.continue,
|
||||
},
|
||||
end: {
|
||||
value: (x?: string | null, y?: string | null) => y ?? x,
|
||||
default: () => end,
|
||||
},
|
||||
errors: {
|
||||
value: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
},
|
||||
filter: {
|
||||
value: (x?: Record<string, unknown> | null, y?: Record<string, unknown> | null) => y ?? x,
|
||||
default: () => filter,
|
||||
},
|
||||
generationAttempts: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
default: () => 0,
|
||||
},
|
||||
generations: {
|
||||
value: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
},
|
||||
hallucinationFailures: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
default: () => 0,
|
||||
},
|
||||
refinePrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.refine,
|
||||
},
|
||||
maxGenerationAttempts: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_GENERATION_ATTEMPTS,
|
||||
},
|
||||
maxHallucinationFailures: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
},
|
||||
maxRepeatedGenerations: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
},
|
||||
refinements: {
|
||||
value: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
},
|
||||
replacements: {
|
||||
value: (x: Replacements, y?: Replacements) => y ?? x,
|
||||
default: () => ({}),
|
||||
},
|
||||
start: {
|
||||
value: (x?: string | null, y?: string | null) => y ?? x,
|
||||
default: () => start,
|
||||
},
|
||||
unrefinedResults: {
|
||||
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
|
||||
default: () => null,
|
||||
},
|
||||
});
|
||||
export const getDefaultGraphAnnotation = ({ end, filter, prompts, start }: Options) =>
|
||||
Annotation.Root({
|
||||
attackDiscoveries: Annotation<AttackDiscovery[] | null>({
|
||||
reducer: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
|
||||
default: () => null,
|
||||
}),
|
||||
attackDiscoveryPrompt: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.default,
|
||||
}),
|
||||
anonymizedAlerts: Annotation<Document[]>({
|
||||
reducer: (x: Document[], y?: Document[]) => y ?? x,
|
||||
default: () => [],
|
||||
}),
|
||||
combinedGenerations: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
}),
|
||||
combinedRefinements: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
}),
|
||||
continuePrompt: Annotation<string, string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
}),
|
||||
end: Annotation<string | null | undefined>({
|
||||
reducer: (x?: string | null, y?: string | null) => y ?? x,
|
||||
default: () => end,
|
||||
}),
|
||||
errors: Annotation<string[], string[]>({
|
||||
reducer: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
}),
|
||||
filter: Annotation<Record<string, unknown> | null | undefined>({
|
||||
reducer: (x?: Record<string, unknown> | null, y?: Record<string, unknown> | null) => y ?? x,
|
||||
default: () => filter,
|
||||
}),
|
||||
generationAttempts: Annotation<number>({
|
||||
reducer: (x: number, y?: number) => y ?? x,
|
||||
default: () => 0,
|
||||
}),
|
||||
generations: Annotation<string[]>({
|
||||
reducer: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
}),
|
||||
hallucinationFailures: Annotation<number>({
|
||||
reducer: (x: number, y?: number) => y ?? x,
|
||||
default: () => 0,
|
||||
}),
|
||||
refinePrompt: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.refine,
|
||||
}),
|
||||
maxGenerationAttempts: Annotation<number>({
|
||||
reducer: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_GENERATION_ATTEMPTS,
|
||||
}),
|
||||
maxHallucinationFailures: Annotation<number>({
|
||||
reducer: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
}),
|
||||
maxRepeatedGenerations: Annotation<number>({
|
||||
reducer: (x: number, y?: number) => y ?? x,
|
||||
default: () => DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
}),
|
||||
refinements: Annotation<string[]>({
|
||||
reducer: (x: string[], y?: string[]) => y ?? x,
|
||||
default: () => [],
|
||||
}),
|
||||
replacements: Annotation<Replacements>({
|
||||
reducer: (x: Replacements, y?: Replacements) => y ?? x,
|
||||
default: () => ({}),
|
||||
}),
|
||||
start: Annotation<string | null | undefined, string | null | undefined>({
|
||||
reducer: (x?: string | null, y?: string | null) => y ?? x,
|
||||
default: () => start,
|
||||
}),
|
||||
unrefinedResults: Annotation<AttackDiscovery[] | null>({
|
||||
reducer: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
|
||||
default: () => null,
|
||||
}),
|
||||
});
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { END, START, StateGraph, StateGraphArgs } from '@langchain/langgraph';
|
||||
import { Annotation, END, START, StateGraph } from '@langchain/langgraph';
|
||||
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
|
||||
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
|
||||
import { StructuredTool } from '@langchain/core/tools';
|
||||
|
@ -60,72 +60,72 @@ export const getDefaultAssistantGraph = ({
|
|||
}: GetDefaultAssistantGraphParams) => {
|
||||
try {
|
||||
// Default graph state
|
||||
const graphState: StateGraphArgs<AgentState>['channels'] = {
|
||||
input: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
const graphAnnotation = Annotation.Root({
|
||||
input: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
lastNode: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
lastNode: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'start',
|
||||
},
|
||||
steps: {
|
||||
value: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
|
||||
}),
|
||||
steps: Annotation<AgentStep[]>({
|
||||
reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
|
||||
default: () => [],
|
||||
},
|
||||
hasRespondStep: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
}),
|
||||
hasRespondStep: Annotation<boolean>({
|
||||
reducer: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
agentOutcome: {
|
||||
value: (
|
||||
}),
|
||||
agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
|
||||
reducer: (
|
||||
x: AgentAction | AgentFinish | undefined,
|
||||
y?: AgentAction | AgentFinish | undefined
|
||||
) => y ?? x,
|
||||
default: () => undefined,
|
||||
},
|
||||
messages: {
|
||||
value: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
|
||||
}),
|
||||
messages: Annotation<BaseMessage[]>({
|
||||
reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
|
||||
default: () => [],
|
||||
},
|
||||
chatTitle: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
chatTitle: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
llmType: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
llmType: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'unknown',
|
||||
},
|
||||
isStream: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
}),
|
||||
isStream: Annotation<boolean>({
|
||||
reducer: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
isOssModel: {
|
||||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
}),
|
||||
isOssModel: Annotation<boolean>({
|
||||
reducer: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
connectorId: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
connectorId: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
conversation: {
|
||||
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
|
||||
}),
|
||||
conversation: Annotation<ConversationResponse | undefined>({
|
||||
reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
|
||||
y ?? x,
|
||||
default: () => undefined,
|
||||
},
|
||||
conversationId: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
conversationId: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
responseLanguage: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
responseLanguage: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => 'English',
|
||||
},
|
||||
provider: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
}),
|
||||
provider: Annotation<string>({
|
||||
reducer: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
};
|
||||
}),
|
||||
});
|
||||
|
||||
// Default node parameters
|
||||
const nodeParams: NodeParamsBase = {
|
||||
|
@ -135,9 +135,7 @@ export const getDefaultAssistantGraph = ({
|
|||
};
|
||||
|
||||
// Put together a new graph using default state from above
|
||||
const graph = new StateGraph({
|
||||
channels: graphState,
|
||||
})
|
||||
const graph = new StateGraph(graphAnnotation)
|
||||
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
|
||||
getPersistedConversation({
|
||||
...nodeParams,
|
||||
|
|
|
@ -13,6 +13,7 @@ import type { KibanaRequest } from '@kbn/core-http-server';
|
|||
import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common';
|
||||
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
|
||||
import { AIMessageChunk } from '@langchain/core/messages';
|
||||
import { AgentFinish } from 'langchain/agents';
|
||||
import { withAssistantSpan } from '../../tracers/apm/with_assistant_span';
|
||||
import { AGENT_NODE_TAG } from './nodes/run_agent';
|
||||
import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph';
|
||||
|
@ -234,7 +235,7 @@ export const invokeGraph = async ({
|
|||
};
|
||||
span.addLabels({ evaluationId: traceOptions?.evaluationId });
|
||||
}
|
||||
const r = await assistantGraph.invoke(inputs, {
|
||||
const result = await assistantGraph.invoke(inputs, {
|
||||
callbacks: [
|
||||
apmTracer,
|
||||
...(traceOptions?.tracers ?? []),
|
||||
|
@ -243,8 +244,8 @@ export const invokeGraph = async ({
|
|||
runName: DEFAULT_ASSISTANT_GRAPH_ID,
|
||||
tags: traceOptions?.tags ?? [],
|
||||
});
|
||||
const output = r.agentOutcome.returnValues.output;
|
||||
const conversationId = r.conversation?.id;
|
||||
const output = (result.agentOutcome as AgentFinish).returnValues.output;
|
||||
const conversationId = result.conversation?.id;
|
||||
if (onLlmResponse) {
|
||||
await onLlmResponse(output, traceData);
|
||||
}
|
||||
|
|
|
@ -89,6 +89,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
// prevents the agent from retrying on failure
|
||||
// failure could be due to bad connector, we should deliver that result to the client asap
|
||||
maxRetries: 0,
|
||||
convertSystemMessageToHumanContent: false,
|
||||
});
|
||||
|
||||
const anonymizationFieldsRes =
|
||||
|
|
|
@ -126,6 +126,7 @@ export const invokeAttackDiscoveryGraph = async ({
|
|||
tags,
|
||||
}
|
||||
);
|
||||
|
||||
const {
|
||||
attackDiscoveries,
|
||||
anonymizedAlerts,
|
||||
|
|
|
@ -27,6 +27,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
|
|||
import { getDefaultArguments } from '@kbn/langchain/server';
|
||||
import { StructuredTool } from '@langchain/core/tools';
|
||||
import {
|
||||
AgentFinish,
|
||||
createOpenAIToolsAgent,
|
||||
createStructuredChatAgent,
|
||||
createToolCallingAgent,
|
||||
|
@ -254,6 +255,7 @@ export const postEvaluateRoute = (
|
|||
signal: abortSignal,
|
||||
streaming: false,
|
||||
maxRetries: 0,
|
||||
convertSystemMessageToHumanContent: false,
|
||||
});
|
||||
const llm = createLlmInstance();
|
||||
const anonymizationFieldsRes =
|
||||
|
@ -400,14 +402,14 @@ export const postEvaluateRoute = (
|
|||
const predict = async (input: { input: string }) => {
|
||||
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
|
||||
|
||||
const r = await graph.invoke(
|
||||
const result = await graph.invoke(
|
||||
{
|
||||
input: input.input,
|
||||
connectorId,
|
||||
conversationId: undefined,
|
||||
responseLanguage: 'English',
|
||||
llmType,
|
||||
isStreaming: false,
|
||||
isStream: false,
|
||||
isOssModel,
|
||||
}, // TODO: Update to use the correct input format per dataset type
|
||||
{
|
||||
|
@ -415,7 +417,7 @@ export const postEvaluateRoute = (
|
|||
tags: ['evaluation'],
|
||||
}
|
||||
);
|
||||
const output = r.agentOutcome.returnValues.output;
|
||||
const output = (result.agentOutcome as AgentFinish).returnValues.output;
|
||||
return output;
|
||||
};
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ import { Annotation, messagesStateReducer } from '@langchain/langgraph';
|
|||
import { uniq } from 'lodash/fp';
|
||||
import type { RuleTranslationResult } from '../../../../../../common/siem_migrations/constants';
|
||||
import type {
|
||||
ElasticRule,
|
||||
ElasticRulePartial,
|
||||
OriginalRule,
|
||||
RuleMigration,
|
||||
} from '../../../../../../common/siem_migrations/model/rule_migration.gen';
|
||||
|
@ -21,7 +21,7 @@ export const migrateRuleState = Annotation.Root({
|
|||
default: () => [],
|
||||
}),
|
||||
original_rule: Annotation<OriginalRule>(),
|
||||
elastic_rule: Annotation<ElasticRule>({
|
||||
elastic_rule: Annotation<ElasticRulePartial>({
|
||||
reducer: (state, action) => ({ ...state, ...action }),
|
||||
}),
|
||||
semantic_query: Annotation<string>({
|
||||
|
|
|
@ -40,7 +40,7 @@ export const translateRuleState = Annotation.Root({
|
|||
}),
|
||||
elastic_rule: Annotation<ElasticRulePartial>({
|
||||
reducer: (state, action) => ({ ...state, ...action }),
|
||||
default: () => ({}),
|
||||
default: () => ({} as ElasticRulePartial),
|
||||
}),
|
||||
validation_errors: Annotation<TranslateRuleValidationErrors>({
|
||||
reducer: (current, value) => value ?? current,
|
||||
|
|
|
@ -9,6 +9,7 @@ import assert from 'assert';
|
|||
import type { AuthenticatedUser, Logger } from '@kbn/core/server';
|
||||
import { abortSignalToPromise, AbortError } from '@kbn/kibana-utils-plugin/server';
|
||||
import type { RunnableConfig } from '@langchain/core/runnables';
|
||||
import type { ElasticRule } from '../../../../../common/siem_migrations/model/rule_migration.gen';
|
||||
import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants';
|
||||
import { initPromisePool } from '../../../../utils/promise_pool';
|
||||
import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client';
|
||||
|
@ -336,7 +337,7 @@ export class RuleMigrationTaskRunner {
|
|||
this.logger.debug(`Translation of rule "${ruleMigration.id}" succeeded`);
|
||||
const ruleMigrationTranslated = {
|
||||
...ruleMigration,
|
||||
elastic_rule: migrationResult.elastic_rule,
|
||||
elastic_rule: migrationResult.elastic_rule as ElasticRule,
|
||||
translation_result: migrationResult.translation_result,
|
||||
comments: migrationResult.comments,
|
||||
};
|
||||
|
|
|
@ -71,6 +71,7 @@ export class ActionsClientChat {
|
|||
llmType,
|
||||
model: connector.config?.defaultModel,
|
||||
streaming: false,
|
||||
convertSystemMessageToHumanContent: false,
|
||||
temperature: 0.05,
|
||||
maxRetries: 1, // Only retry once inside the model, we will handle backoff retries in the task runner
|
||||
telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId },
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue