[8.x] Update langchain (main) (#205553) (#212567)

# Backport

This will backport the following commits from `main` to `8.x`:
 - [Update langchain (main) #205553](https://github.com/elastic/kibana/pull/205553)

<!--- Backport version: 9.6.6 -->

### Questions?
Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport)

Co-authored-by: elastic-renovate-prod[bot] <174716857+elastic-renovate-prod[bot]@users.noreply.github.com>
This commit is contained in:
Kenneth Kreindler 2025-02-27 13:35:22 +00:00 committed by GitHub
parent 4ee1e8756c
commit 22fc6d8bea
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1004 additions and 1042 deletions

View file

@ -79,8 +79,8 @@
"resolutions": {
"**/@bazel/typescript/protobufjs": "6.11.4",
"**/@hello-pangea/dnd": "16.6.0",
"**/@langchain/core": "^0.3.16",
"**/@langchain/google-common": "^0.1.1",
"**/@langchain/core": "^0.3.40",
"**/@langchain/google-common": "^0.1.8",
"**/@types/node": "20.10.5",
"**/@typescript-eslint/utils": "5.62.0",
"**/chokidar": "^3.5.3",
@ -88,10 +88,14 @@
"**/globule/minimatch": "^3.1.2",
"**/hoist-non-react-statics": "^3.3.2",
"**/isomorphic-fetch/node-fetch": "^2.6.7",
"**/langchain": "^0.3.5",
"**/langchain": "^0.3.15",
"**/remark-parse/trim": "1.0.1",
"**/sharp": "0.32.6",
"**/typescript": "5.1.6",
"@aws-sdk/client-bedrock-agent-runtime": "^3.744.0",
"@aws-sdk/client-bedrock-runtime": "^3.744.0",
"@aws-sdk/client-kendra": "3.744.0",
"@aws-sdk/credential-provider-node": "3.744.0",
"@storybook/react-docgen-typescript-plugin": "1.0.6--canary.9.cd77847.0",
"@types/react": "~18.2.0",
"@types/react-dom": "~18.2.0",
@ -102,7 +106,7 @@
"@appland/sql-parser": "^1.5.1",
"@aws-crypto/sha256-js": "^5.2.0",
"@aws-crypto/util": "^5.2.0",
"@aws-sdk/client-bedrock-runtime": "^3.687.0",
"@aws-sdk/client-bedrock-runtime": "^3.744.0",
"@babel/runtime": "^7.24.7",
"@dagrejs/dagre": "^1.1.4",
"@dnd-kit/core": "^6.1.0",
@ -141,7 +145,7 @@
"@formatjs/intl-relativetimeformat": "^11.2.12",
"@formatjs/intl-utils": "^3.8.4",
"@formatjs/ts-transformer": "^3.13.14",
"@google/generative-ai": "^0.7.0",
"@google/generative-ai": "^0.21.0",
"@grpc/grpc-js": "^1.8.22",
"@hapi/accept": "^6.0.3",
"@hapi/boom": "^10.0.1",
@ -1023,14 +1027,14 @@
"@kbn/xstate-utils": "link:src/platform/packages/shared/kbn-xstate-utils",
"@kbn/zod": "link:src/platform/packages/shared/kbn-zod",
"@kbn/zod-helpers": "link:src/platform/packages/shared/kbn-zod-helpers",
"@langchain/aws": "^0.1.2",
"@langchain/community": "0.3.14",
"@langchain/core": "^0.3.16",
"@langchain/google-common": "^0.1.1",
"@langchain/google-genai": "^0.1.2",
"@langchain/google-vertexai": "^0.1.0",
"@langchain/langgraph": "0.2.19",
"@langchain/openai": "^0.3.11",
"@langchain/aws": "^0.1.3",
"@langchain/community": "^0.3.29",
"@langchain/core": "^0.3.40",
"@langchain/google-common": "^0.1.8",
"@langchain/google-genai": "^0.1.8",
"@langchain/google-vertexai": "^0.1.8",
"@langchain/langgraph": "^0.2.45",
"@langchain/openai": "^0.4.4",
"@langtrase/trace-attributes": "^7.5.0",
"@launchdarkly/node-server-sdk": "^9.7.2",
"@launchdarkly/openfeature-node-server": "^1.0.0",
@ -1057,14 +1061,11 @@
"@paralleldrive/cuid2": "^2.2.2",
"@reduxjs/toolkit": "1.9.7",
"@slack/webhook": "^7.0.1",
"@smithy/eventstream-codec": "^3.1.10",
"@smithy/eventstream-serde-node": "^3.0.12",
"@smithy/middleware-stack": "^3.0.11",
"@smithy/node-http-handler": "^3.3.3",
"@smithy/protocol-http": "^4.1.7",
"@smithy/signature-v4": "^4.2.3",
"@smithy/types": "^3.7.1",
"@smithy/util-utf8": "^3.0.0",
"@smithy/eventstream-codec": "^4.0.1",
"@smithy/eventstream-serde-node": "^4.0.1",
"@smithy/middleware-stack": "^4.0.1",
"@smithy/types": "^4.1.0",
"@smithy/util-utf8": "^4.0.0",
"@tanstack/react-query": "^4.29.12",
"@tanstack/react-query-devtools": "^4.29.12",
"@turf/along": "6.0.1",
@ -1176,8 +1177,8 @@
"jsonwebtoken": "^9.0.2",
"jsts": "^1.6.2",
"kea": "^2.6.0",
"langchain": "^0.3.5",
"langsmith": "^0.2.5",
"langchain": "^0.3.15",
"langsmith": "^0.3.7",
"launchdarkly-js-client-sdk": "^3.5.0",
"load-json-file": "^6.2.0",
"lodash": "^4.17.21",

View file

@ -703,4 +703,4 @@
"datasourceTemplate": "docker"
}
]
}
}

View file

@ -90,12 +90,6 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
streaming,
// matters only for the LangSmith logs (Metadata > Invocation Params), which are misleading if this is not set
modelName: model ?? DEFAULT_OPEN_AI_MODEL,
// these have to be initialized, but are not actually used since we override the openai client with the actions client
azureOpenAIApiKey: 'nothing',
azureOpenAIApiDeploymentName: 'nothing',
azureOpenAIApiInstanceName: 'nothing',
azureOpenAIBasePath: 'nothing',
azureOpenAIApiVersion: 'nothing',
openAIApiKey: '',
});
this.#actionsClient = actionsClient;

View file

@ -140,6 +140,7 @@ const callOptions = {
const handleLLMNewToken = jest.fn();
const callRunManager = {
handleLLMNewToken,
handleCustomEvent: jest.fn().mockResolvedValue({}),
} as unknown as CallbackManagerForLLMRun;
const onFailedAttempt = jest.fn();
const defaultArgs = {
@ -149,6 +150,7 @@ const defaultArgs = {
streaming: false,
maxRetries: 0,
onFailedAttempt,
convertSystemMessageToHumanContent: false,
};
const testMessage = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
@ -188,7 +190,6 @@ describe('ActionsClientChatVertexAI', () => {
describe('_generate streaming: false', () => {
it('returns the expected content when _generate is invoked', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);
const result = await actionsClientChatVertexAI._generate(
callMessages,
callOptions,
@ -220,7 +221,7 @@ describe('ActionsClientChatVertexAI', () => {
expect(onFailedAttempt).toHaveBeenCalled();
});
it('rejects with the expected error the message has invalid content', async () => {
it('resolves to expected result when message has invalid content', async () => {
actionsClient.execute.mockImplementation(
jest.fn().mockResolvedValue({
data: {
@ -235,7 +236,7 @@ describe('ActionsClientChatVertexAI', () => {
await expect(
actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager)
).rejects.toThrowError("Cannot read properties of undefined (reading 'text')");
).resolves.toEqual({ generations: [], llmOutput: {} });
});
});

View file

@ -17,7 +17,7 @@ import { Readable } from 'stream';
import { Logger } from '@kbn/logging';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { GeminiPartText } from '@langchain/google-common/dist/types';
import { GeminiPartText, GeminiRequest } from '@langchain/google-common/dist/types';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import {
convertResponseBadFinishReasonToErrorMsg,
@ -36,6 +36,7 @@ export interface CustomChatModelInput extends BaseChatModelParams {
model?: string;
maxTokens?: number;
telemetryMetadata?: TelemetryMetadata;
convertSystemMessageToHumanContent?: boolean;
}
export class ActionsClientChatVertexAI extends ChatVertexAI {
@ -56,7 +57,9 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
this.#model = props.model;
this.#actionsClient = actionsClient;
this.#connectorId = connectorId;
const client = this.buildClient(props);
// apiKey is required to build the client but is overridden by the actionsClient
const client = this.buildClient({ apiKey: 'nothing', ...props });
this.connection = new ActionsClientChatConnection(
{
...this,
@ -80,7 +83,7 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const parameters = this.invocationParams(options);
const data = await this.connection.formatData(messages, parameters);
const data = (await this.connection.formatData(messages, parameters)) as GeminiRequest;
const stream = await this.caller.callWithOptions({ signal: options?.signal }, async () => {
const systemPart: GeminiPartText | undefined = data?.systemInstruction
?.parts?.[0] as unknown as GeminiPartText;

View file

@ -8,6 +8,7 @@
import {
ChatConnection,
GeminiContent,
GeminiRequest,
GoogleAbstractedClient,
GoogleAIBaseLLMInput,
GoogleLLMResponse,
@ -46,7 +47,7 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
this.temperature = fields.temperature ?? 0;
const nativeFormatData = this.formatData.bind(this);
this.formatData = async (data, options) => {
const result = await nativeFormatData(data, options);
const result = (await nativeFormatData(data, options)) as GeminiRequest;
if (result?.contents != null && result?.contents.length) {
// ensure there are not 2 messages in a row from the same role,
// if there are combine them

View file

@ -6,8 +6,7 @@
*/
import { BaseCallbackHandlerInput } from '@langchain/core/callbacks/base';
import type { Run } from 'langsmith/schemas';
import { BaseTracer } from '@langchain/core/tracers/base';
import { BaseTracer, Run } from '@langchain/core/tracers/base';
import agent from 'elastic-apm-node';
import type { Logger } from '@kbn/core/server';

View file

@ -6,7 +6,7 @@
*/
import { BaseCallbackHandlerInput } from '@langchain/core/callbacks/base';
import type { Run } from 'langsmith/schemas';
import { Run } from 'langsmith/schemas';
import { BaseTracer } from '@langchain/core/tracers/base';
import { AnalyticsServiceSetup, Logger } from '@kbn/core/server';

View file

@ -117,7 +117,7 @@ describe('geminiAdapter', () => {
name: 'myFunction',
parameters: {
properties: {},
type: 'OBJECT',
type: 'object',
},
},
{
@ -128,11 +128,11 @@ describe('geminiAdapter', () => {
foo: {
description: 'foo',
enum: undefined,
type: 'STRING',
type: 'string',
},
},
required: ['foo'],
type: 'OBJECT',
type: 'object',
},
},
],

View file

@ -110,7 +110,7 @@ function toolsToGemini(tools: ToolOptions['tools']): Gemini.Tool[] {
parameters: schema
? toolSchemaToGemini({ schema })
: {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
type: Gemini.SchemaType.OBJECT,
properties: {},
},
};
@ -130,13 +130,13 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
switch (def.type) {
case 'array':
return {
type: Gemini.FunctionDeclarationSchemaType.ARRAY,
type: Gemini.SchemaType.ARRAY,
description: def.description,
items: convertSchemaType({ def: def.items }) as Gemini.FunctionDeclarationSchema,
};
case 'object':
return {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
type: Gemini.SchemaType.OBJECT,
description: def.description,
required: def.required as string[],
properties: def.properties
@ -152,19 +152,19 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
};
case 'string':
return {
type: Gemini.FunctionDeclarationSchemaType.STRING,
type: Gemini.SchemaType.STRING,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
case 'boolean':
return {
type: Gemini.FunctionDeclarationSchemaType.BOOLEAN,
type: Gemini.SchemaType.BOOLEAN,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
case 'number':
return {
type: Gemini.FunctionDeclarationSchemaType.NUMBER,
type: Gemini.SchemaType.NUMBER,
description: def.description,
enum: def.enum ? (def.enum as string[]) : def.const ? [def.const] : undefined,
};
@ -172,7 +172,7 @@ function toolSchemaToGemini({ schema }: { schema: ToolSchema }): Gemini.Function
};
return {
type: Gemini.FunctionDeclarationSchemaType.OBJECT,
type: Gemini.SchemaType.OBJECT,
required: schema.required as string[],
properties: Object.entries(schema.properties ?? {}).reduce<
Record<string, Gemini.FunctionDeclarationSchemaProperty>

View file

@ -9,7 +9,6 @@ import { ElasticsearchClient, Logger } from '@kbn/core/server';
import { Replacements } from '@kbn/elastic-assistant-common';
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
import type { ActionsClientLlm } from '@kbn/langchain/server';
import type { CompiledStateGraph } from '@langchain/langgraph';
import { END, START, StateGraph } from '@langchain/langgraph';
import { CombinedPrompts } from './nodes/helpers/prompts';
@ -18,11 +17,10 @@ import { getGenerateOrEndEdge } from './edges/generate_or_end';
import { getGenerateOrRefineOrEndEdge } from './edges/generate_or_refine_or_end';
import { getRefineOrEndEdge } from './edges/refine_or_end';
import { getRetrieveAnonymizedAlertsOrGenerateEdge } from './edges/retrieve_anonymized_alerts_or_generate';
import { getDefaultGraphState } from './state';
import { getDefaultGraphAnnotation } from './state';
import { getGenerateNode } from './nodes/generate';
import { getRefineNode } from './nodes/refine';
import { getRetrieveAnonymizedAlertsNode } from './nodes/retriever';
import type { GraphState } from './types';
export interface GetDefaultAttackDiscoveryGraphParams {
alertsIndexPattern?: string;
@ -61,13 +59,9 @@ export const getDefaultAttackDiscoveryGraph = ({
replacements,
size,
start,
}: GetDefaultAttackDiscoveryGraphParams): CompiledStateGraph<
GraphState,
Partial<GraphState>,
'generate' | 'refine' | 'retrieve_anonymized_alerts' | '__start__'
> => {
}: GetDefaultAttackDiscoveryGraphParams) => {
try {
const graphState = getDefaultGraphState({ end, filter, prompts, start });
const graphState = getDefaultGraphAnnotation({ end, filter, prompts, start });
// get nodes:
const retrieveAnonymizedAlertsNode = getRetrieveAnonymizedAlertsNode({
@ -103,7 +97,7 @@ export const getDefaultAttackDiscoveryGraph = ({
getRetrieveAnonymizedAlertsOrGenerateEdge(logger);
// create the graph:
const graph = new StateGraph<GraphState>({ channels: graphState })
const graph = new StateGraph(graphState)
.addNode(NodeType.RETRIEVE_ANONYMIZED_ALERTS_NODE, retrieveAnonymizedAlertsNode)
.addNode(NodeType.GENERATE_NODE, generateNode)
.addNode(NodeType.REFINE_NODE, refineNode)

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { getDefaultGraphState } from '.';
import { getDefaultGraphAnnotation } from '.';
import {
DEFAULT_MAX_GENERATION_ATTEMPTS,
DEFAULT_MAX_HALLUCINATION_FAILURES,
@ -26,118 +26,122 @@ const prompts = {
};
describe('getDefaultGraphState', () => {
it('returns the expected default attackDiscoveries', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.attackDiscoveries?.default?.()).toBeNull();
expect(graphAnnotation.spec.attackDiscoveries.value).toBeNull();
});
it('returns the expected default attackDiscoveryPrompt', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.attackDiscoveryPrompt?.default?.()).toEqual(defaultAttackDiscoveryPrompt);
expect(graphAnnotation.spec.attackDiscoveryPrompt.value).toEqual(defaultAttackDiscoveryPrompt);
});
it('returns the expected default empty collection of anonymizedAlerts', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.anonymizedAlerts?.default?.()).toHaveLength(0);
expect(graphAnnotation.spec.anonymizedAlerts.value).toHaveLength(0);
});
it('returns the expected default combinedGenerations state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.combinedGenerations?.default?.()).toBe('');
expect(graphAnnotation.spec.combinedGenerations.value).toBe('');
});
it('returns the expected default combinedRefinements state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.combinedRefinements?.default?.()).toBe('');
expect(graphAnnotation.spec.combinedRefinements.value).toBe('');
});
it('returns the expected default errors state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.errors?.default?.()).toHaveLength(0);
expect(graphAnnotation.spec.errors.value).toHaveLength(0);
});
it('return the expected default generationAttempts state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.generationAttempts?.default?.()).toBe(0);
expect(graphAnnotation.spec.generationAttempts.value).toBe(0);
});
it('returns the expected default generations state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.generations?.default?.()).toHaveLength(0);
expect(graphAnnotation.spec.generations.value).toHaveLength(0);
});
it('returns the expected default hallucinationFailures state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.hallucinationFailures?.default?.()).toBe(0);
expect(graphAnnotation.spec.hallucinationFailures.value).toBe(0);
});
it('returns the expected default refinePrompt state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.refinePrompt?.default?.()).toEqual(defaultRefinePrompt);
expect(graphAnnotation.spec.refinePrompt.value).toEqual(defaultRefinePrompt);
});
it('returns the expected default maxGenerationAttempts state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.maxGenerationAttempts?.default?.()).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
expect(graphAnnotation.spec.maxGenerationAttempts.value).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
});
it('returns the expected default maxHallucinationFailures state', () => {
const state = getDefaultGraphState({ prompts });
expect(state.maxHallucinationFailures?.default?.()).toBe(DEFAULT_MAX_HALLUCINATION_FAILURES);
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(graphAnnotation.spec.maxHallucinationFailures.value).toBe(
DEFAULT_MAX_HALLUCINATION_FAILURES
);
});
it('returns the expected default maxRepeatedGenerations state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.maxRepeatedGenerations?.default?.()).toBe(DEFAULT_MAX_REPEATED_GENERATIONS);
expect(graphAnnotation.spec.maxRepeatedGenerations.value).toBe(
DEFAULT_MAX_REPEATED_GENERATIONS
);
});
it('returns the expected default refinements state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.refinements?.default?.()).toHaveLength(0);
expect(graphAnnotation.spec.refinements.value).toHaveLength(0);
});
it('returns the expected default replacements state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.replacements?.default?.()).toEqual({});
expect(graphAnnotation.spec.replacements.value).toEqual({});
});
it('returns the expected default unrefinedResults state', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.unrefinedResults?.default?.()).toBeNull();
expect(graphAnnotation.spec.unrefinedResults.value).toBeNull();
});
it('returns the expected default end', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.end?.default?.()).toBeUndefined();
expect(graphAnnotation.spec.end.value).toBeUndefined();
});
it('returns the expected end when it is provided', () => {
const end = '2025-01-02T00:00:00.000Z';
const state = getDefaultGraphState({ prompts, end });
const graphAnnotation = getDefaultGraphAnnotation({ prompts, end });
expect(state.end?.default?.()).toEqual(end);
expect(graphAnnotation.spec.end.value).toEqual(end);
});
it('returns the expected default filter to be undefined', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.filter?.default?.()).toBeUndefined();
expect(graphAnnotation.spec.filter.value).toBeUndefined();
});
it('returns the expected filter when it is provided', () => {
@ -162,22 +166,22 @@ describe('getDefaultGraphState', () => {
},
};
const state = getDefaultGraphState({ prompts, filter });
const graphAnnotation = getDefaultGraphAnnotation({ prompts, filter });
expect(state.filter?.default?.()).toEqual(filter);
expect(graphAnnotation.spec.filter.value).toEqual(filter);
});
it('returns the expected default start to be undefined', () => {
const state = getDefaultGraphState({ prompts });
const graphAnnotation = getDefaultGraphAnnotation({ prompts });
expect(state.start?.default?.()).toBeUndefined();
expect(graphAnnotation.spec.start.value).toBeUndefined();
});
it('returns the expected start when it is provided', () => {
const start = '2025-01-01T00:00:00.000Z';
const state = getDefaultGraphState({ prompts, start });
const graphAnnotation = getDefaultGraphAnnotation({ prompts, start });
expect(state.start?.default?.()).toEqual(start);
expect(graphAnnotation.spec.start.value).toEqual(start);
});
});

View file

@ -7,7 +7,7 @@
import { AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
import type { Document } from '@langchain/core/documents';
import type { StateGraphArgs } from '@langchain/langgraph';
import { Annotation } from '@langchain/langgraph';
import { AttackDiscoveryPrompts } from '../nodes/helpers/prompts';
import {
@ -15,7 +15,6 @@ import {
DEFAULT_MAX_HALLUCINATION_FAILURES,
DEFAULT_MAX_REPEATED_GENERATIONS,
} from '../constants';
import type { GraphState } from '../types';
export interface Options {
end?: string;
@ -24,90 +23,86 @@ export interface Options {
start?: string;
}
export const getDefaultGraphState = ({
end,
filter,
prompts,
start,
}: Options): StateGraphArgs<GraphState>['channels'] => ({
attackDiscoveries: {
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
default: () => null,
},
attackDiscoveryPrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => prompts.default,
},
anonymizedAlerts: {
value: (x: Document[], y?: Document[]) => y ?? x,
default: () => [],
},
combinedGenerations: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
combinedRefinements: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
continuePrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => prompts.continue,
},
end: {
value: (x?: string | null, y?: string | null) => y ?? x,
default: () => end,
},
errors: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
filter: {
value: (x?: Record<string, unknown> | null, y?: Record<string, unknown> | null) => y ?? x,
default: () => filter,
},
generationAttempts: {
value: (x: number, y?: number) => y ?? x,
default: () => 0,
},
generations: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
hallucinationFailures: {
value: (x: number, y?: number) => y ?? x,
default: () => 0,
},
refinePrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => prompts.refine,
},
maxGenerationAttempts: {
value: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_GENERATION_ATTEMPTS,
},
maxHallucinationFailures: {
value: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_HALLUCINATION_FAILURES,
},
maxRepeatedGenerations: {
value: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_REPEATED_GENERATIONS,
},
refinements: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
replacements: {
value: (x: Replacements, y?: Replacements) => y ?? x,
default: () => ({}),
},
start: {
value: (x?: string | null, y?: string | null) => y ?? x,
default: () => start,
},
unrefinedResults: {
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
default: () => null,
},
});
export const getDefaultGraphAnnotation = ({ end, filter, prompts, start }: Options) =>
Annotation.Root({
attackDiscoveries: Annotation<AttackDiscovery[] | null>({
reducer: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
default: () => null,
}),
attackDiscoveryPrompt: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => prompts.default,
}),
anonymizedAlerts: Annotation<Document[]>({
reducer: (x: Document[], y?: Document[]) => y ?? x,
default: () => [],
}),
combinedGenerations: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
combinedRefinements: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
continuePrompt: Annotation<string, string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
}),
end: Annotation<string | null | undefined>({
reducer: (x?: string | null, y?: string | null) => y ?? x,
default: () => end,
}),
errors: Annotation<string[], string[]>({
reducer: (x: string[], y?: string[]) => y ?? x,
default: () => [],
}),
filter: Annotation<Record<string, unknown> | null | undefined>({
reducer: (x?: Record<string, unknown> | null, y?: Record<string, unknown> | null) => y ?? x,
default: () => filter,
}),
generationAttempts: Annotation<number>({
reducer: (x: number, y?: number) => y ?? x,
default: () => 0,
}),
generations: Annotation<string[]>({
reducer: (x: string[], y?: string[]) => y ?? x,
default: () => [],
}),
hallucinationFailures: Annotation<number>({
reducer: (x: number, y?: number) => y ?? x,
default: () => 0,
}),
refinePrompt: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => prompts.refine,
}),
maxGenerationAttempts: Annotation<number>({
reducer: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_GENERATION_ATTEMPTS,
}),
maxHallucinationFailures: Annotation<number>({
reducer: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_HALLUCINATION_FAILURES,
}),
maxRepeatedGenerations: Annotation<number>({
reducer: (x: number, y?: number) => y ?? x,
default: () => DEFAULT_MAX_REPEATED_GENERATIONS,
}),
refinements: Annotation<string[]>({
reducer: (x: string[], y?: string[]) => y ?? x,
default: () => [],
}),
replacements: Annotation<Replacements>({
reducer: (x: Replacements, y?: Replacements) => y ?? x,
default: () => ({}),
}),
start: Annotation<string | null | undefined, string | null | undefined>({
reducer: (x?: string | null, y?: string | null) => y ?? x,
default: () => start,
}),
unrefinedResults: Annotation<AttackDiscovery[] | null>({
reducer: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
default: () => null,
}),
});

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { END, START, StateGraph, StateGraphArgs } from '@langchain/langgraph';
import { Annotation, END, START, StateGraph } from '@langchain/langgraph';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
import { StructuredTool } from '@langchain/core/tools';
@ -60,72 +60,72 @@ export const getDefaultAssistantGraph = ({
}: GetDefaultAssistantGraphParams) => {
try {
// Default graph state
const graphState: StateGraphArgs<AgentState>['channels'] = {
input: {
value: (x: string, y?: string) => y ?? x,
const graphAnnotation = Annotation.Root({
input: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
},
lastNode: {
value: (x: string, y?: string) => y ?? x,
}),
lastNode: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'start',
},
steps: {
value: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
}),
steps: Annotation<AgentStep[]>({
reducer: (x: AgentStep[], y: AgentStep[]) => x.concat(y),
default: () => [],
},
hasRespondStep: {
value: (x: boolean, y?: boolean) => y ?? x,
}),
hasRespondStep: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
agentOutcome: {
value: (
}),
agentOutcome: Annotation<AgentAction | AgentFinish | undefined>({
reducer: (
x: AgentAction | AgentFinish | undefined,
y?: AgentAction | AgentFinish | undefined
) => y ?? x,
default: () => undefined,
},
messages: {
value: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
}),
messages: Annotation<BaseMessage[]>({
reducer: (x: BaseMessage[], y: BaseMessage[]) => y ?? x,
default: () => [],
},
chatTitle: {
value: (x: string, y?: string) => y ?? x,
}),
chatTitle: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
},
llmType: {
value: (x: string, y?: string) => y ?? x,
}),
llmType: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'unknown',
},
isStream: {
value: (x: boolean, y?: boolean) => y ?? x,
}),
isStream: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isOssModel: {
value: (x: boolean, y?: boolean) => y ?? x,
}),
isOssModel: Annotation<boolean>({
reducer: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
connectorId: {
value: (x: string, y?: string) => y ?? x,
}),
connectorId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
}),
conversation: Annotation<ConversationResponse | undefined>({
reducer: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
default: () => undefined,
},
conversationId: {
value: (x: string, y?: string) => y ?? x,
}),
conversationId: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
},
responseLanguage: {
value: (x: string, y?: string) => y ?? x,
}),
responseLanguage: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => 'English',
},
provider: {
value: (x: string, y?: string) => y ?? x,
}),
provider: Annotation<string>({
reducer: (x: string, y?: string) => y ?? x,
default: () => '',
},
};
}),
});
// Default node parameters
const nodeParams: NodeParamsBase = {
@ -135,9 +135,7 @@ export const getDefaultAssistantGraph = ({
};
// Put together a new graph using default state from above
const graph = new StateGraph({
channels: graphState,
})
const graph = new StateGraph(graphAnnotation)
.addNode(NodeType.GET_PERSISTED_CONVERSATION, (state: AgentState) =>
getPersistedConversation({
...nodeParams,

View file

@ -13,6 +13,7 @@ import type { KibanaRequest } from '@kbn/core-http-server';
import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { AIMessageChunk } from '@langchain/core/messages';
import { AgentFinish } from 'langchain/agents';
import { withAssistantSpan } from '../../tracers/apm/with_assistant_span';
import { AGENT_NODE_TAG } from './nodes/run_agent';
import { DEFAULT_ASSISTANT_GRAPH_ID, DefaultAssistantGraph } from './graph';
@ -234,7 +235,7 @@ export const invokeGraph = async ({
};
span.addLabels({ evaluationId: traceOptions?.evaluationId });
}
const r = await assistantGraph.invoke(inputs, {
const result = await assistantGraph.invoke(inputs, {
callbacks: [
apmTracer,
...(traceOptions?.tracers ?? []),
@ -243,8 +244,8 @@ export const invokeGraph = async ({
runName: DEFAULT_ASSISTANT_GRAPH_ID,
tags: traceOptions?.tags ?? [],
});
const output = r.agentOutcome.returnValues.output;
const conversationId = r.conversation?.id;
const output = (result.agentOutcome as AgentFinish).returnValues.output;
const conversationId = result.conversation?.id;
if (onLlmResponse) {
await onLlmResponse(output, traceData);
}

View file

@ -89,6 +89,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// prevents the agent from retrying on failure
// failure could be due to bad connector, we should deliver that result to the client asap
maxRetries: 0,
convertSystemMessageToHumanContent: false,
});
const anonymizationFieldsRes =

View file

@ -126,6 +126,7 @@ export const invokeAttackDiscoveryGraph = async ({
tags,
}
);
const {
attackDiscoveries,
anonymizedAlerts,

View file

@ -27,6 +27,7 @@ import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/
import { getDefaultArguments } from '@kbn/langchain/server';
import { StructuredTool } from '@langchain/core/tools';
import {
AgentFinish,
createOpenAIToolsAgent,
createStructuredChatAgent,
createToolCallingAgent,
@ -254,6 +255,7 @@ export const postEvaluateRoute = (
signal: abortSignal,
streaming: false,
maxRetries: 0,
convertSystemMessageToHumanContent: false,
});
const llm = createLlmInstance();
const anonymizationFieldsRes =
@ -400,14 +402,14 @@ export const postEvaluateRoute = (
const predict = async (input: { input: string }) => {
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
const r = await graph.invoke(
const result = await graph.invoke(
{
input: input.input,
connectorId,
conversationId: undefined,
responseLanguage: 'English',
llmType,
isStreaming: false,
isStream: false,
isOssModel,
}, // TODO: Update to use the correct input format per dataset type
{
@ -415,7 +417,7 @@ export const postEvaluateRoute = (
tags: ['evaluation'],
}
);
const output = r.agentOutcome.returnValues.output;
const output = (result.agentOutcome as AgentFinish).returnValues.output;
return output;
};

View file

@ -10,7 +10,7 @@ import { Annotation, messagesStateReducer } from '@langchain/langgraph';
import { uniq } from 'lodash/fp';
import type { RuleTranslationResult } from '../../../../../../common/siem_migrations/constants';
import type {
ElasticRule,
ElasticRulePartial,
OriginalRule,
RuleMigration,
} from '../../../../../../common/siem_migrations/model/rule_migration.gen';
@ -21,7 +21,7 @@ export const migrateRuleState = Annotation.Root({
default: () => [],
}),
original_rule: Annotation<OriginalRule>(),
elastic_rule: Annotation<ElasticRule>({
elastic_rule: Annotation<ElasticRulePartial>({
reducer: (state, action) => ({ ...state, ...action }),
}),
semantic_query: Annotation<string>({

View file

@ -40,7 +40,7 @@ export const translateRuleState = Annotation.Root({
}),
elastic_rule: Annotation<ElasticRulePartial>({
reducer: (state, action) => ({ ...state, ...action }),
default: () => ({}),
default: () => ({} as ElasticRulePartial),
}),
validation_errors: Annotation<TranslateRuleValidationErrors>({
reducer: (current, value) => value ?? current,

View file

@ -9,6 +9,7 @@ import assert from 'assert';
import type { AuthenticatedUser, Logger } from '@kbn/core/server';
import { abortSignalToPromise, AbortError } from '@kbn/kibana-utils-plugin/server';
import type { RunnableConfig } from '@langchain/core/runnables';
import type { ElasticRule } from '../../../../../common/siem_migrations/model/rule_migration.gen';
import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants';
import { initPromisePool } from '../../../../utils/promise_pool';
import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client';
@ -336,7 +337,7 @@ export class RuleMigrationTaskRunner {
this.logger.debug(`Translation of rule "${ruleMigration.id}" succeeded`);
const ruleMigrationTranslated = {
...ruleMigration,
elastic_rule: migrationResult.elastic_rule,
elastic_rule: migrationResult.elastic_rule as ElasticRule,
translation_result: migrationResult.translation_result,
comments: migrationResult.comments,
};

View file

@ -71,6 +71,7 @@ export class ActionsClientChat {
llmType,
model: connector.config?.defaultModel,
streaming: false,
convertSystemMessageToHumanContent: false,
temperature: 0.05,
maxRetries: 1, // Only retry once inside the model, we will handle backoff retries in the task runner
telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId },

1539
yarn.lock

File diff suppressed because it is too large Load diff