[Security solution] [Ai Assistant] ES|QL generation with self healing (#213726)

This commit is contained in:
Kenneth Kreindler 2025-04-16 09:12:49 +01:00 committed by GitHub
parent ec88cca373
commit 1d430d4d35
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
71 changed files with 3382 additions and 173 deletions

View file

@ -80,3 +80,16 @@ export const ATTACK_DISCOVERY_SCHEDULES_BY_ID_ENABLE =
export const ATTACK_DISCOVERY_SCHEDULES_BY_ID_DISABLE =
`${ATTACK_DISCOVERY_SCHEDULES}/{id}/_disable` as const;
export const ATTACK_DISCOVERY_SCHEDULES_FIND = `${ATTACK_DISCOVERY_SCHEDULES}/_find` as const;
/**
* The server timeout is set to 4 minutes to allow for long-running requests.
* The allows slower LLMs (like Llama 3.1 70B) and complex tasks such as ESQL generation to complete
* without being interrupted.
*/
export const INVOKE_LLM_SERVER_TIMEOUT = 4 * 60 * 1000; // 4 minutes
/**
* The client timeout is set to 3 seconds less than the server timeout to prevent
* the `core-http-browser` from retrying the request.
*
*/
export const INVOKE_LL_CLIENT_TIMEOUT = INVOKE_LLM_SERVER_TIMEOUT - 3000; // 4 minutes - 3 second

View file

@ -21,4 +21,5 @@ export type AssistantFeatureKey = keyof AssistantFeatures;
export const defaultAssistantFeatures = Object.freeze({
assistantModelEvaluation: false,
defendInsights: true,
advancedEsqlGeneration: false,
});

View file

@ -7,21 +7,12 @@
import { HttpSetup } from '@kbn/core-http-browser';
import { useCallback, useRef, useState } from 'react';
import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common';
import { ApiConfig, INVOKE_LL_CLIENT_TIMEOUT, Replacements } from '@kbn/elastic-assistant-common';
import moment from 'moment';
import { useAssistantContext } from '../../assistant_context';
import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api';
import * as i18n from './translations';
/**
* TODO: This is a workaround to solve the issue with the long standing server tasks while cahtting with the assistant.
* Some models (like Llama 3.1 70B) can perform poorly and be slow which leads to a long time to handle the request.
* The `core-http-browser` has a timeout of two minutes after which it will re-try the request. In combination with the slow model it can lead to
* a situation where core http client will initiate same request again and again.
* To avoid this, we abort http request after timeout which is slightly below two minutes.
*/
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds
interface SendMessageProps {
apiConfig: ApiConfig;
http: HttpSetup;
@ -52,7 +43,7 @@ export const useSendMessage = (): UseSendMessage => {
const timeoutId = setTimeout(() => {
abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR);
abortController.current = new AbortController();
}, EXECUTE_ACTION_TIMEOUT);
}, INVOKE_LL_CLIENT_TIMEOUT);
try {
return await fetchConnectorExecuteAction({

View file

@ -96,7 +96,7 @@ export interface UseAssistantContext {
actionTypeRegistry: ActionTypeRegistryContract;
alertsIndexPattern: string | undefined;
assistantAvailability: AssistantAvailability;
assistantFeatures: AssistantFeatures;
assistantFeatures: Partial<AssistantFeatures>;
assistantStreamingEnabled: boolean;
assistantTelemetry?: AssistantTelemetry;
augmentMessageCodeBlocks: (

View file

@ -10,11 +10,13 @@ import { Logger } from '@kbn/core/server';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { get } from 'lodash/fp';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { ChatOpenAI } from '@langchain/openai';
import { ChatOpenAI, OpenAIClient } from '@langchain/openai';
import { Stream } from 'openai/streaming';
import type OpenAI from 'openai';
import { PublicMethodsOf } from '@kbn/utility-types';
import { parseChatCompletion } from 'openai/lib/parser';
import { ChatCompletionCreateParams } from 'openai/resources';
import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import {
InferenceChatCompleteParamsSchema,
@ -125,6 +127,17 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
return 'base_chat_model';
}
async betaParsedCompletionWithRetry(
request: OpenAI.ChatCompletionCreateParamsNonStreaming
): Promise<ReturnType<OpenAIClient['beta']['chat']['completions']['parse']>> {
return this.completionWithRetry(request).then((response) =>
parseChatCompletion(
response,
this.constructBody(request, this.llmType) as ChatCompletionCreateParams
)
);
}
async completionWithRetry(
request: OpenAI.ChatCompletionCreateParamsStreaming
): Promise<AsyncIterable<OpenAI.ChatCompletionChunk>>;
@ -132,7 +145,6 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
async completionWithRetry(
request: OpenAI.ChatCompletionCreateParamsNonStreaming
): Promise<OpenAI.ChatCompletion>;
async completionWithRetry(
completionRequest:
| OpenAI.ChatCompletionCreateParamsStreaming
@ -195,29 +207,8 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
};
signal?: AbortSignal;
} {
const body = {
temperature: this.#temperature,
// possible client model override
// security sends this from connectors, it is only missing from preconfigured connectors
// this should be undefined otherwise so the connector handles the model (stack_connector has access to preconfigured connector model values)
...(llmType === 'inference' ? {} : { model: this.model }),
n: completionRequest.n,
stop: completionRequest.stop,
tools: completionRequest.tools,
...(completionRequest.tool_choice ? { tool_choice: completionRequest.tool_choice } : {}),
// deprecated, use tools
...(completionRequest.functions ? { functions: completionRequest?.functions } : {}),
// ensure we take the messages from the completion request, not the client request
messages: completionRequest.messages.map((message) => ({
role: message.role,
content: message.content ?? '',
...('name' in message ? { name: message?.name } : {}),
...('tool_calls' in message ? { tool_calls: message?.tool_calls } : {}),
...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
// deprecated, use tool_calls
...('function_call' in message ? { function_call: message?.function_call } : {}),
})),
};
const body = this.constructBody(completionRequest, llmType);
const subAction =
llmType === 'inference'
? completionRequest.stream
@ -248,4 +239,40 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
signal: this.#signal,
};
}
constructBody(
completionRequest:
| OpenAI.ChatCompletionCreateParamsNonStreaming
| OpenAI.ChatCompletionCreateParamsStreaming,
llmType: string
) {
const body = {
temperature: this.#temperature,
// possible client model override
// security sends this from connectors, it is only missing from preconfigured connectors
// this should be undefined otherwise so the connector handles the model (stack_connector has access to preconfigured connector model values)
...(llmType === 'inference' ? {} : { model: this.model }),
n: completionRequest.n,
stop: completionRequest.stop,
tools: completionRequest.tools,
...(completionRequest.response_format
? { response_format: completionRequest.response_format }
: {}),
...(completionRequest.tool_choice ? { tool_choice: completionRequest.tool_choice } : {}),
// deprecated, use tools
...(completionRequest.functions ? { functions: completionRequest?.functions } : {}),
// ensure we take the messages from the completion request, not the client request
messages: completionRequest.messages.map((message) => ({
role: message.role,
content: message.content ?? '',
...('name' in message ? { name: message?.name } : {}),
...('tool_calls' in message ? { tool_calls: message?.tool_calls } : {}),
...('tool_call_id' in message ? { tool_call_id: message?.tool_call_id } : {}),
// deprecated, use tool_calls
...('function_call' in message ? { function_call: message?.function_call } : {}),
})),
};
return body;
}
}

View file

@ -150,6 +150,7 @@ export const InvokeAIActionParamsSchema = schema.object({
schema.nullable(schema.oneOf([schema.string(), schema.arrayOf(schema.string())]))
),
temperature: schema.maybe(schema.number()),
response_format: schema.maybe(schema.any()),
// abort signal from client
signal: schema.maybe(schema.any()),
timeout: schema.maybe(schema.number()),

View file

@ -27,7 +27,7 @@ import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { CoreRequestHandlerContext } from '@kbn/core/server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import type { AssistantTool, ElasticAssistantApiRequestHandlerContext } from '../../../types';
import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
import { AIAssistantConversationsDataClient } from '../../../ai_assistant_data_clients/conversations';
import { AIAssistantDataClient } from '../../../ai_assistant_data_clients';
@ -46,6 +46,7 @@ export interface AssistantDataClients {
export interface AgentExecutorParams<T extends boolean> {
abortSignal?: AbortSignal;
assistantContext: ElasticAssistantApiRequestHandlerContext;
alertsIndexPattern?: string;
actionsClient: PublicMethodsOf<ActionsClient>;
assistantTools?: AssistantTool[];

View file

@ -30,6 +30,7 @@ import { agentRunableFactory } from './agentRunnable';
export const callAssistantGraph: AgentExecutor<true | false> = async ({
abortSignal,
assistantContext,
actionsClient,
alertsIndexPattern,
assistantTools = [],
@ -113,6 +114,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// Fetch any applicable tools that the source plugin may have registered
const assistantToolParams: AssistantToolParams = {
alertsIndexPattern,
assistantContext,
anonymizationFields,
connectorId,
contentReferencesStore,
@ -127,6 +129,8 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
request,
size,
telemetry,
createLlmInstance,
isOssModel,
};
const tools: StructuredTool[] = (

View file

@ -28,10 +28,36 @@ export const localToolPrompts: Prompt[] = [
- breakdown or filter ES|QL queries that are displayed on the current page
- convert queries from another language to ES|QL
- asks general questions about ES|QL
ALWAYS use this tool to generate ES|QL queries or explain anything about the ES|QL query language rather than coming up with your own answer.`,
},
},
{
promptId: 'GenerateESQLTool',
promptGroupId,
prompt: {
default: `You MUST use the "GenerateESQLTool" function when the user wants to:
- generate an ES|QL query
- convert queries from another language to ES|QL they can run on their cluster
ALWAYS use this tool to generate ES|QL queries and never generate ES|QL any other way.`,
},
},
{
promptId: 'AskAboutEsqlTool',
promptGroupId,
prompt: {
default: `You MUST use the "AskAboutEsqlTool" function when the user:
- asks for help with ES|QL
- asks about ES|QL syntax
- asks for ES|QL examples
- asks for ES|QL documentation
- asks for ES|QL best practices
- asks for ES|QL optimization
Never use this tool when they user wants to generate a ES|QL for their data.`,
},
},
{
promptId: 'ProductDocumentationTool',
promptGroupId,

View file

@ -19,6 +19,7 @@ import {
newContentReferencesStore,
pruneContentReferences,
ChatCompleteRequestQuery,
INVOKE_LLM_SERVER_TIMEOUT,
} from '@kbn/elastic-assistant-common';
import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common';
import { getRequestAbortedSignal } from '@kbn/data-plugin/server';
@ -54,6 +55,11 @@ export const chatCompleteRoute = (
requiredPrivileges: ['elasticAssistant'],
},
},
options: {
timeout: {
idleSocket: INVOKE_LLM_SERVER_TIMEOUT,
},
},
})
.addVersion(
{

View file

@ -352,6 +352,7 @@ export const postEvaluateRoute = (
// Fetch any applicable tools that the source plugin may have registered
const assistantToolParams: AssistantToolParams = {
anonymizationFields,
assistantContext,
esClient,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
@ -368,6 +369,7 @@ export const postEvaluateRoute = (
size,
telemetry: ctx.elasticAssistant.telemetry,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
createLlmInstance,
};
const tools: StructuredTool[] = (
@ -484,7 +486,7 @@ export const postEvaluateRoute = (
experimentPrefix: name,
client: new Client({ apiKey: langSmithApiKey }),
// prevent rate limiting and unexpected multiple experiment runs
maxConcurrency: 5,
maxConcurrency: 3,
})
.then((output) => {
logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`);

View file

@ -17,7 +17,7 @@ export const logsIndexCreateRequest: IndicesCreateRequest = {
destination: {
properties: {
ip: { type: 'ip' },
port: { type: 'integer' },
port: { type: 'long' },
address: { type: 'keyword' },
},
},
@ -62,6 +62,7 @@ export const logsIndexCreateRequest: IndicesCreateRequest = {
network: {
properties: {
bytes: { type: 'long' },
direction: { type: 'keyword' },
},
},
process: {

View file

@ -319,6 +319,7 @@ export const langChainExecute = async ({
// Shared executor params
const executorParams: AgentExecutorParams<boolean> = {
abortSignal,
assistantContext,
dataClients,
alertsIndexPattern: request.body.alertsIndexPattern,
core: context.core,

View file

@ -18,6 +18,7 @@ import {
Replacements,
pruneContentReferences,
ExecuteConnectorRequestQuery,
INVOKE_LLM_SERVER_TIMEOUT,
POST_ACTIONS_CONNECTOR_EXECUTE,
} from '@kbn/elastic-assistant-common';
import { buildRouteValidationWithZod } from '@kbn/elastic-assistant-common/impl/schemas/common';
@ -46,6 +47,11 @@ export const postActionsConnectorExecuteRoute = (
requiredPrivileges: ['elasticAssistant'],
},
},
options: {
timeout: {
idleSocket: INVOKE_LLM_SERVER_TIMEOUT,
},
},
})
.addVersion(
{

View file

@ -248,6 +248,7 @@ export type AssistantToolLlm =
export interface AssistantToolParams {
alertsIndexPattern?: string;
assistantContext?: ElasticAssistantApiRequestHandlerContext;
anonymizationFields?: AnonymizationFieldResponse[];
inference?: InferenceServerStart;
isEnabledKnowledgeBase: boolean;
@ -270,4 +271,8 @@ export interface AssistantToolParams {
>;
size?: number;
telemetry?: AnalyticsServiceSetup;
createLlmInstance?: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
}

View file

@ -114,6 +114,11 @@ export const allowedExperimentalValues = Object.freeze({
*/
assistantModelEvaluation: false,
/**
* Enables advanced ESQL generation for the Assistant.
*/
advancedEsqlGeneration: false,
/**
* Enables the Managed User section inside the new user details flyout.
*/

Binary file not shown.

After

Width:  |  Height:  |  Size: 102 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

After

Width:  |  Height:  |  Size: 48 KiB

View file

@ -27,7 +27,7 @@
"test:generate:serverless-dev": "NODE_TLS_REJECT_UNAUTHORIZED=0 node --no-warnings scripts/endpoint/resolver_generator --node https://elastic_serverless:changeme@127.0.0.1:9200 --kibana http://elastic_serverless:changeme@127.0.0.1:5601",
"mappings:generate": "node scripts/mappings/mappings_generator",
"mappings:load": "node scripts/mappings/mappings_loader",
"siem-migrations:graph:draw": "node scripts/siem_migration/draw_graphs",
"langgraph:draw": "node scripts/langgraph/draw_graphs",
"junit:transform": "node scripts/junit_transformer --pathPattern '../../../../../target/kibana-security-solution/cypress/results/*.xml' --rootDirectory ../../../../../ --reportName 'Security Solution Cypress' --writeInPlace",
"openapi:generate": "node scripts/openapi/generate",
"openapi:generate:debug": "node --inspect-brk scripts/openapi/generate",

View file

@ -7,6 +7,7 @@
import type {
ActionsClientChatOpenAI,
ActionsClientChatVertexAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { Logger } from '@kbn/logging';
@ -14,6 +15,10 @@ import { ToolingLog } from '@kbn/tooling-log';
import { FakeLLM } from '@langchain/core/utils/testing';
import fs from 'fs/promises';
import path from 'path';
import type { ElasticsearchClient, KibanaRequest } from '@kbn/core/server';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type { ActionsClientChatBedrockConverse } from '@kbn/langchain/server';
import { getGenerateEsqlGraph as getGenerateEsqlAgent } from '../../server/assistant/tools/esql/graphs/generate_esql/generate_esql';
import { getRuleMigrationAgent } from '../../server/lib/siem_migrations/rules/task/agent';
import type { RuleMigrationsRetriever } from '../../server/lib/siem_migrations/rules/task/retrievers';
import type { EsqlKnowledgeBase } from '../../server/lib/siem_migrations/rules/task/util/esql_knowledge_base';
@ -34,7 +39,7 @@ const createLlmInstance = () => {
return mockLlm;
};
async function getAgentGraph(logger: Logger): Promise<Drawable> {
async function getSiemMigrationGraph(logger: Logger): Promise<Drawable> {
const model = createLlmInstance();
const telemetryClient = {} as SiemMigrationTelemetryClient;
const graph = getRuleMigrationAgent({
@ -47,6 +52,22 @@ async function getAgentGraph(logger: Logger): Promise<Drawable> {
return graph.getGraphAsync({ xray: true });
}
async function getGenerateEsqlGraph(logger: Logger): Promise<Drawable> {
const graph = getGenerateEsqlAgent({
esClient: {} as unknown as ElasticsearchClient,
connectorId: 'test-connector-id',
inference: {} as unknown as InferenceServerStart,
logger,
request: {} as unknown as KibanaRequest,
createLlmInstance: () =>
({ bindTools: () => null } as unknown as
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI),
});
return graph.getGraphAsync({ xray: true });
}
export const drawGraph = async ({
getGraphAsync,
outputFilename,
@ -61,6 +82,7 @@ export const drawGraph = async ({
logger.info('Compiling graph');
const outputPath = path.join(__dirname, outputFilename);
const graph = await getGraphAsync(logger);
logger.info('Drawing graph');
const output = await graph.drawMermaidPng();
const buffer = Buffer.from(await output.arrayBuffer());
logger.info(`Writing graph to ${outputPath}`);
@ -69,7 +91,11 @@ export const drawGraph = async ({
export const draw = async () => {
await drawGraph({
getGraphAsync: getAgentGraph,
getGraphAsync: getGenerateEsqlGraph,
outputFilename: '../../docs/generate_esql/img/generate_esql_graph.png',
});
await drawGraph({
getGraphAsync: getSiemMigrationGraph,
outputFilename: '../../docs/siem_migration/img/agent_graph.png',
});
};

View file

@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { tool } from '@langchain/core/tools';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { z } from '@kbn/zod';
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './utils/common';
export type ESQLToolParams = AssistantToolParams & {
assistantContext: ElasticAssistantApiRequestHandlerContext;
};
const TOOL_NAME = 'AskAboutEsqlTool';
const toolDetails = {
id: 'ask-about-esql-tool',
name: TOOL_NAME,
// note: this description is overwritten when `getTool` is called
// local definitions exist ../elastic_assistant/server/lib/prompt/tool_prompts.ts
// local definitions can be overwritten by security-ai-prompt integration definitions
description: `You MUST use the "${TOOL_NAME}" function when the user:
- asks for help with ES|QL
- asks about ES|QL syntax
- asks for ES|QL examples
- asks for ES|QL documentation
- asks for ES|QL best practices
- asks for ES|QL optimization
Never use this tool when the user wants to generate a ES|QL for their data.`,
};
export const ASK_ABOUT_ESQL_TOOL: AssistantTool = {
...toolDetails,
sourceRegister: APP_UI_ID,
isSupported: (params: AssistantToolParams): params is ESQLToolParams => {
const { inference, connectorId, assistantContext } = params;
return (
inference != null &&
connectorId != null &&
assistantContext != null &&
assistantContext.getRegisteredFeatures('securitySolutionUI').advancedEsqlGeneration
);
},
getTool(params: AssistantToolParams) {
if (!this.isSupported(params)) return null;
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
if (inference == null || connectorId == null) return null;
const callNaturalLanguageToEsql = async (question: string) => {
return lastValueFrom(
naturalLanguageToEsql({
client: inference.getClient({ request }),
connectorId,
input: question,
functionCalling: 'auto',
logger,
})
);
};
return tool(
async (input) => {
const generateEvent = await callNaturalLanguageToEsql(input.question);
const answer = generateEvent.content ?? 'An error occurred in the tool';
logger.debug(`Received response from NL to ESQL tool: ${answer}`);
return answer;
},
{
name: toolDetails.name,
description:
(params.description || toolDetails.description) +
(isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),
tags: ['esql', 'query-generation', 'knowledge-base'],
}
);
},
};

View file

@ -1,15 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getPromptSuffixForOssModel = (toolName: string) => `
When using ${toolName} tool ALWAYS pass the user's questions directly as input into the tool.
Always return value from ${toolName} tool as is.
The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks.
It is important that ES|QL query is preceeded by a new line.`;

View file

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { tool } from '@langchain/core/tools';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { z } from '@kbn/zod';
import { HumanMessage } from '@langchain/core/messages';
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './utils/common';
import { getGenerateEsqlGraph } from './graphs/generate_esql/generate_esql';
export type GenerateEsqlParams = AssistantToolParams & {
assistantContext: ElasticAssistantApiRequestHandlerContext;
createLlmInstance: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
};
const TOOL_NAME = 'GenerateESQLTool';
const toolDetails = {
id: 'gnerate-esql-tool',
name: TOOL_NAME,
// note: this description is overwritten when `getTool` is called
// local definitions exist ../elastic_assistant/server/lib/prompt/tool_prompts.ts
// local definitions can be overwritten by security-ai-prompt integration definitions
description: `You MUST use the "${TOOL_NAME}" function when the user wants to:
- generate an ES|QL query
- convert queries from another language to ES|QL they can run on their cluster
ALWAYS use this tool to generate ES|QL queries and never generate ES|QL any other way.`,
};
export const GENERATE_ESQL_TOOL: AssistantTool = {
...toolDetails,
sourceRegister: APP_UI_ID,
isSupported: (params: AssistantToolParams): params is GenerateEsqlParams => {
const { inference, connectorId, assistantContext, createLlmInstance } = params;
return (
inference != null &&
connectorId != null &&
assistantContext != null &&
assistantContext.getRegisteredFeatures('securitySolutionUI').advancedEsqlGeneration &&
createLlmInstance != null
);
},
getTool(params: AssistantToolParams) {
if (!this.isSupported(params)) return null;
const { connectorId, inference, logger, request, isOssModel, esClient, createLlmInstance } =
params as GenerateEsqlParams;
if (inference == null || connectorId == null) return null;
const selfHealingGraph = getGenerateEsqlGraph({
esClient,
connectorId,
inference,
logger,
request,
createLlmInstance,
});
return tool(
async ({ question }) => {
const result = await selfHealingGraph.invoke(
{
messages: [new HumanMessage({ content: question })],
input: { question },
},
{ recursionLimit: 30 }
);
const { messages } = result;
const lastMessage = messages[messages.length - 1];
return lastMessage.content;
},
{
name: toolDetails.name,
description:
(params.description || toolDetails.description) +
(isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
schema: z.object({
question: z
.string()
.describe(
`The user's exact question about ES|QL. Provide as much detail as possible including the name of the index and fields if the user has provided those.`
),
}),
tags: ['esql', 'query-generation', 'knowledge-base'],
}
);
},
};

View file

@ -0,0 +1,99 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { END, START, StateGraph } from '@langchain/langgraph';
import { ToolNode } from '@langchain/langgraph/prebuilt';
import type { ElasticsearchClient } from '@kbn/core/server';
import { AnalyzeIndexPatternAnnotation } from './state';
import {
GET_FIELD_DESCRIPTORS,
EXPLORE_PARTIAL_INDEX_RESPONDER,
EXPLORE_PARTIAL_INDEX_AGENT,
ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT,
TOOLS,
} from './constants';
import type { CreateLlmInstance } from '../../utils/common';
import { messageContainsToolCalls } from '../../utils/common';
import { getInspectIndexMappingTool } from '../../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
import { getFieldDescriptors } from './nodes/get_field_descriptors/get_field_descriptors';
import { getAnalyzeCompressedIndexMappingAgent } from './nodes/analyze_compressed_index_mapping_agent/analyze_compressed_index_mapping_agent';
import { getExplorePartialIndexMappingAgent } from './nodes/explore_partial_index_mapping_agent/explore_partial_index_mapping_agent';
import { getExplorePartialIndexMappingResponder } from './nodes/explore_partial_index_mapping_responder/explore_partial_index_mapping_responder';
export const getAnalyzeIndexPatternGraph = ({
esClient,
createLlmInstance,
}: {
esClient: ElasticsearchClient;
createLlmInstance: CreateLlmInstance;
}) => {
const graph = new StateGraph(AnalyzeIndexPatternAnnotation)
.addNode(GET_FIELD_DESCRIPTORS, getFieldDescriptors({ esClient }))
.addNode(
ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT,
getAnalyzeCompressedIndexMappingAgent({ createLlmInstance })
)
.addNode(
EXPLORE_PARTIAL_INDEX_AGENT,
getExplorePartialIndexMappingAgent({ esClient, createLlmInstance })
)
.addNode(TOOLS, (state: typeof AnalyzeIndexPatternAnnotation.State) => {
const { input } = state;
if (input === undefined) {
throw new Error('Input is required');
}
const inspectIndexMappingTool = getInspectIndexMappingTool({
esClient,
indexPattern: input.indexPattern,
});
const tools = [inspectIndexMappingTool];
const toolNode = new ToolNode(tools);
return toolNode.invoke(state);
})
.addNode(
EXPLORE_PARTIAL_INDEX_RESPONDER,
getExplorePartialIndexMappingResponder({ createLlmInstance })
)
.addEdge(START, GET_FIELD_DESCRIPTORS)
.addConditionalEdges(
GET_FIELD_DESCRIPTORS,
(state: typeof AnalyzeIndexPatternAnnotation.State) => {
if (state.fieldDescriptors === undefined) {
throw new Error('Expected field descriptors to be defined');
}
return state.fieldDescriptors.length > 2500
? EXPLORE_PARTIAL_INDEX_AGENT
: ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT;
},
{
[EXPLORE_PARTIAL_INDEX_AGENT]: EXPLORE_PARTIAL_INDEX_AGENT,
[ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT]: ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT,
}
)
.addEdge(ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT, END)
.addConditionalEdges(
EXPLORE_PARTIAL_INDEX_AGENT,
(state: typeof AnalyzeIndexPatternAnnotation.State) => {
if (messageContainsToolCalls(state.messages[state.messages.length - 1])) {
return TOOLS;
}
return EXPLORE_PARTIAL_INDEX_RESPONDER;
},
{
[TOOLS]: TOOLS,
[EXPLORE_PARTIAL_INDEX_RESPONDER]: EXPLORE_PARTIAL_INDEX_RESPONDER,
}
)
.addEdge(TOOLS, EXPLORE_PARTIAL_INDEX_AGENT)
.addEdge(EXPLORE_PARTIAL_INDEX_RESPONDER, END)
.compile();
return graph;
};

View file

@ -0,0 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const GET_FIELD_DESCRIPTORS = 'getFieldDescriptors';
export const EXPLORE_PARTIAL_INDEX_AGENT = 'explorePartialIndexAgent';
export const EXPLORE_PARTIAL_INDEX_RESPONDER = 'explorePartialIndexResponder';
export const ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT = 'analyzeCompressedIndexMappingAgent';
export const TOOLS = 'tools';

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { Command } from '@langchain/langgraph';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';
import { mapFieldDescriptorToNestedObject } from '../../../../tools/inspect_index_mapping_tool/inspect_index_utils';
import type { CreateLlmInstance } from '../../../../utils/common';
import type { AnalyzeIndexPatternAnnotation } from '../../state';
import { compressMapping } from './compress_mapping';
const structuredOutput = z.object({
containsRequiredFieldsForQuery: z
.boolean()
.describe('Whether the index pattern contains the required fields for the query'),
});
export const getAnalyzeCompressedIndexMappingAgent = ({
createLlmInstance,
}: {
createLlmInstance: CreateLlmInstance;
}) => {
const llm = createLlmInstance();
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
const { fieldDescriptors, input } = state;
if (fieldDescriptors === undefined) {
throw new Error('State fieldDescriptors is undefined');
}
if (input === undefined) {
throw new Error('State input is undefined');
}
const prunedFields = fieldDescriptors.map((fieldDescriptor) => ({
name: fieldDescriptor.name,
type: fieldDescriptor.esTypes[0],
}));
const nestedObject = mapFieldDescriptorToNestedObject(prunedFields);
const compressedIndexMapping = compressMapping(nestedObject);
const result = await llm
.withStructuredOutput(structuredOutput, { name: 'indexMappingAnalysis' })
.invoke([
new SystemMessage({
content:
'You are a security analyst who is an expert in Elasticsearch and particularly at analyzing indices. ' +
'You will be given an compressed index mapping containing available fields and types and an explanation ' +
'of the query that we are trying to generate. Analyze the index mapping and determine whether it contains the ' +
'fields required to write the query. You do not need to generate the query right now, just determine whether the' +
' index mapping contains the fields required to write the query.',
}),
new HumanMessage({
content: `Query objective:\n'${input.question}'\n\nIndex pattern:\n'${input.indexPattern}'\n\nCompressed index mapping:\n${compressedIndexMapping}`,
}),
]);
return new Command({
update: {
output: {
containsRequiredFieldsForQuery: result.containsRequiredFieldsForQuery,
context: compressedIndexMapping,
},
},
});
};
};

View file

@ -0,0 +1,132 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
// Define types for the input object structure
interface TypedProperty {
type: string;
}
interface NestedObject {
[key: string]: TypedProperty | NestedObject;
}
// Type for stack entries
interface StackEntry {
obj: NestedObject;
result: string | null;
processed: boolean;
}
function hasTypeProperty(obj: TypedProperty | NestedObject): obj is TypedProperty {
return 'type' in obj && typeof obj.type === 'string';
}
export function compressMapping(obj: NestedObject): string {
// Group top-level keys by type
const typeGroups: Record<string, string[]> = {};
const nestedObjs: Array<[string, NestedObject]> = [];
// Organize properties by type or as nested objects
Object.entries(obj).forEach(([key, value]) => {
if (hasTypeProperty(value)) {
const type = value.type;
typeGroups[type] = typeGroups[type] || [];
typeGroups[type].push(key);
} else {
nestedObjs.push([key, value]);
}
});
// Format the result
const parts: string[] = [];
// Add grouped simple properties
for (const [type, keys] of Object.entries(typeGroups)) {
parts.push(`${keys.join(',')}:${type}`);
}
// Add nested objects
for (const [key, nestedObj] of nestedObjs) {
parts.push(`${key}:{${formatNestedObject(nestedObj)}}`);
}
return parts.join('\n');
}
function formatNestedObject(obj: NestedObject): string {
// Stack to track objects that need processing
const stack: StackEntry[] = [{ obj, result: null, processed: false }];
// Cache for already processed objects to avoid cycles
const cache = new Map<NestedObject, string>();
while (stack.length > 0) {
const current = stack[stack.length - 1];
/* eslint-disable no-continue */
// If already processed this object, pop it and continue
if (current.processed) {
stack.pop();
continue;
}
// If we've seen this object before, use cached result
const cachedValue = cache.get(current.obj);
if (cachedValue) {
current.result = cachedValue;
stack.pop();
continue;
}
/* eslint-enable no-continue */
// Group properties by type
const typeGroups: Record<string, string[]> = {};
const nestedProps: Array<[string, NestedObject]> = [];
for (const [propKey, propValue] of Object.entries(current.obj)) {
if (hasTypeProperty(propValue)) {
// Group by type
const type = propValue.type;
typeGroups[type] = typeGroups[type] || [];
typeGroups[type].push(propKey);
} else {
// Track nested objects
nestedProps.push([propKey, propValue]);
}
}
// Process all nested objects first
let allNestedProcessed = true;
const formattedParts: string[] = [];
// Add type groups to formatted parts
for (const [type, keys] of Object.entries(typeGroups)) {
formattedParts.push(keys.length === 1 ? `${keys[0]}:${type}` : `${keys.join(',')}:${type}`);
}
// Process nested objects
for (const [propKey, propValue] of nestedProps) {
// If we have a cached result, use it
if (cache.has(propValue)) {
formattedParts.push(`${propKey}:{${cache.get(propValue)}}`);
} else {
// Push to stack for processing
stack.push({ obj: propValue, result: null, processed: false });
allNestedProcessed = false;
}
}
// If all nested objects are processed, finalize this object
if (allNestedProcessed) {
current.result = formattedParts.join(',');
cache.set(current.obj, current.result);
current.processed = true;
}
}
// Return result from the first item we pushed to stack (root object)
return cache.get(obj) || '';
}

View file

@ -0,0 +1,55 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { CreateLlmInstance } from '../../../../utils/common';
import type { AnalyzeIndexPatternAnnotation } from '../../state';
import { getInspectIndexMappingTool } from '../../../../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
export const getExplorePartialIndexMappingAgent = ({
createLlmInstance,
esClient,
}: {
createLlmInstance: CreateLlmInstance;
esClient: ElasticsearchClient;
}) => {
const llm = createLlmInstance();
const tool = getInspectIndexMappingTool({
esClient,
indexPattern: 'placeholder',
});
const llmWithTools = llm.bindTools([tool]);
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
const { messages, input } = state;
if (input === undefined) {
throw new Error('Input is required');
}
const result = await llmWithTools.invoke([
new SystemMessage({
content:
'You are an expert in Elastic Search and particularly at analyzing indices. You have been given a function that allows you' +
' to explore a large index mapping. Use this function to explore the index mapping and determine whether it contains the fields ' +
'required to write the query.',
}),
new HumanMessage({
content: `Does the index mapping contain the fields required to generate a query that does the following:\n${input.question}`,
}),
...messages,
]);
return new Command({
update: {
messages: [result],
},
});
};
};

View file

@ -0,0 +1,53 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { z } from '@kbn/zod';
import { Command } from '@langchain/langgraph';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';
import type { CreateLlmInstance } from '../../../../utils/common';
import type { AnalyzeIndexPatternAnnotation } from '../../state';
import { buildContext } from './utils';
const structuredOutput = z.object({
containsRequiredFieldsForQuery: z
.boolean()
.describe('Whether the index pattern contains the required fields for the query'),
});
export const getExplorePartialIndexMappingResponder = ({
createLlmInstance,
}: {
createLlmInstance: CreateLlmInstance;
}) => {
const llm = createLlmInstance();
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
const result = await llm
.withStructuredOutput(structuredOutput, { name: 'indexMappingAnalysis' })
.invoke([
new SystemMessage({
content:
'You are an expert at parsing text. You have been given a text and need to parse it into the provided schema.',
}),
new HumanMessage({
content: lastMessage.content,
}),
]);
return new Command({
update: {
output: {
containsRequiredFieldsForQuery: result.containsRequiredFieldsForQuery,
context: JSON.stringify(buildContext(messages)),
},
},
});
};
};

View file

@ -0,0 +1,73 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { AIMessage, BaseMessage } from '@langchain/core/messages';
import { ToolMessage } from '@langchain/core/messages';
import type { ToolCall } from '@langchain/core/dist/messages/tool';
import { set } from '@kbn/safer-lodash-set';
import { toolDetails } from '../../../../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
import { messageContainsToolCalls } from '../../../../utils/common';
export const buildContext = (messages: BaseMessage[]): Record<string, unknown> => {
const orderedInspectIndexMappingToolCalls: ToolCall[] = messages
.filter((message) => messageContainsToolCalls(message))
.flatMap((message) => (message as AIMessage).tool_calls)
.filter((toolCall) => toolCall !== undefined)
.filter((toolCall) => (toolCall as ToolCall).name === toolDetails.name) as ToolCall[];
const orderedInspectIndexMappingToolCallIds = orderedInspectIndexMappingToolCalls.map(
(toolCall) => toolCall.id
);
const inspectIndexMappingToolCallByIds = orderedInspectIndexMappingToolCalls.reduce(
(acc, toolCall) => {
const toolCallId = toolCall.id;
if (toolCallId !== undefined) {
acc[toolCallId] = toolCall;
}
return acc;
},
{} as Record<string, ToolCall>
);
const orderedInspectIndexMappingToolMessages = messages
.filter((message) => message instanceof ToolMessage)
.filter((message) => (message as ToolMessage).tool_call_id in inspectIndexMappingToolCallByIds)
.map((message) => message as ToolMessage)
.sort(
(a, b) =>
orderedInspectIndexMappingToolCallIds.indexOf(a.tool_call_id) -
orderedInspectIndexMappingToolCallIds.indexOf(b.tool_call_id)
);
let context = {};
/* eslint-disable no-continue */
for (const toolMessage of orderedInspectIndexMappingToolMessages) {
const toolCall = inspectIndexMappingToolCallByIds[toolMessage.tool_call_id];
if (toolCall.args.property !== undefined) {
if (toolCall.args.property === '') {
try {
context = JSON.parse(toolMessage.content as string);
} catch (e) {
continue;
}
} else {
try {
const parsedContent = JSON.parse(toolMessage.content as string);
set(context, toolCall.args.property, parsedContent);
} catch (e) {
continue;
}
}
}
}
/* eslint-enable no-continue */
return context;
};

View file

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { IndexPatternsFetcher } from '@kbn/data-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import { Command } from '@langchain/langgraph';
import type { AnalyzeIndexPatternAnnotation } from '../../state';
export const getFieldDescriptors = ({ esClient }: { esClient: ElasticsearchClient }) => {
const indexPatternsFetcher = new IndexPatternsFetcher(esClient);
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
if (state.input === undefined) {
throw new Error('State input is undefined');
}
const { indexPattern } = state.input;
const { fields: fieldDescriptors } = await indexPatternsFetcher.getFieldsForWildcard({
pattern: indexPattern,
fieldCapsOptions: {
allow_no_indices: false,
includeUnmapped: false,
},
});
return new Command({
update: {
fieldDescriptors,
},
});
};
};

View file

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { FieldDescriptor } from '@kbn/data-views-plugin/server';
import type { BaseMessage } from '@langchain/core/messages';
import { Annotation, messagesStateReducer } from '@langchain/langgraph';
export const AnalyzeIndexPatternAnnotation = Annotation.Root({
messages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
input: Annotation<{ question: string; indexPattern: string } | undefined>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => undefined,
}),
fieldDescriptors: Annotation<FieldDescriptor[] | undefined>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => undefined,
}),
output: Annotation<{ containsRequiredFieldsForQuery: boolean; context: string } | undefined>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => undefined,
}),
});

View file

@ -0,0 +1,16 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE = 'buildErrorReportFromLastMessage';
export const BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE = 'buildSuccessReportFromLastMessage';
export const BUILD_UNVALIDATED_REPORT_FROM_LAST_MESSAGE_NODE =
'buildUnvalidatedReportFromLastMessage';
export const NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE = 'nlToEsqlAgentWithoutValidation';
export const VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE = 'validateEsqlFromLastMessageNode';
export const SELECT_INDEX_PATTERN_GRAPH = 'selectIndexPatternGraph';
export const NL_TO_ESQL_AGENT_NODE = 'nlToEsqlAgent';
export const TOOLS_NODE = 'toolsNode';

View file

@ -0,0 +1,160 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
import { END, START, StateGraph } from '@langchain/langgraph';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import { ToolNode } from '@langchain/langgraph/prebuilt';
import { GenerateEsqlAnnotation } from './state';
import {
nlToEsqlAgentStepRouter,
selectIndexStepRouter,
validateEsqlFromLastMessageStepRouter,
} from './step_router';
import { getNlToEsqlAgent } from './nodes/nl_to_esql_agent/nl_to_esql_agent';
import { getValidateEsqlInLastMessageNode } from './nodes/validate_esql_in_last_message_node/validate_esql_in_last_message_node';
import { getInspectIndexMappingTool } from '../../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
import {
BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE,
BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE,
BUILD_UNVALIDATED_REPORT_FROM_LAST_MESSAGE_NODE,
NL_TO_ESQL_AGENT_NODE,
NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE,
SELECT_INDEX_PATTERN_GRAPH,
TOOLS_NODE,
VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE,
} from './constants';
import { getBuildErrorReportFromLastMessageNode } from './nodes/build_error_report_from_last_message/build_error_report_from_last_message';
import { getBuildSuccessReportFromLastMessageNode } from './nodes/build_success_report_from_last_message/build_success_report_from_last_message';
import { getNlToEsqlAgentWithoutValidation } from './nodes/nl_to_esql_agent_without_validation/nl_to_esql_agent_without_validation';
import { getBuildUnvalidatedReportFromLastMessageNode } from './nodes/build_unvalidated_report_from_last_message/build_unvalidated_report_from_last_message';
import { getSelectIndexPattern } from './nodes/select_index_pattern/select_index_pattern';
import { getSelectIndexPatternGraph } from '../select_index_pattern/select_index_pattern';
export const getGenerateEsqlGraph = ({
esClient,
connectorId,
inference,
logger,
request,
createLlmInstance,
}: {
esClient: ElasticsearchClient;
connectorId: string;
inference: InferenceServerStart;
logger: Logger;
request: KibanaRequest;
createLlmInstance: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
}) => {
const nlToEsqlAgentNode = getNlToEsqlAgent({
connectorId,
inference,
logger,
request,
tools: [
getInspectIndexMappingTool({
esClient,
indexPattern: 'placeholder',
}),
],
});
const nlToEsqlAgentWithoutValidationNode = getNlToEsqlAgentWithoutValidation({
connectorId,
inference,
logger,
request,
});
const validateEsqlInLastMessageNode = getValidateEsqlInLastMessageNode({
esClient,
});
const buildErrorReportFromLastMessageNode = getBuildErrorReportFromLastMessageNode();
const buildSuccessReportFromLastMessageNode = getBuildSuccessReportFromLastMessageNode();
const buildUnvalidatedReportFromLastMessageNode = getBuildUnvalidatedReportFromLastMessageNode();
const identifyIndexGraph = getSelectIndexPatternGraph({
esClient,
createLlmInstance,
});
const selectIndexPatternSubGraph = getSelectIndexPattern({
identifyIndexGraph,
});
const graph = new StateGraph(GenerateEsqlAnnotation)
// Nodes
.addNode(SELECT_INDEX_PATTERN_GRAPH, selectIndexPatternSubGraph, {
subgraphs: [identifyIndexGraph],
})
.addNode(NL_TO_ESQL_AGENT_NODE, nlToEsqlAgentNode, { retryPolicy: { maxAttempts: 3 } })
.addNode(TOOLS_NODE, (state: typeof GenerateEsqlAnnotation.State) => {
const { selectedIndexPattern } = state;
if (selectedIndexPattern == null) {
throw new Error('Input is required');
}
const inspectIndexMappingTool = getInspectIndexMappingTool({
esClient,
indexPattern: selectedIndexPattern,
});
const tools = [inspectIndexMappingTool];
const toolNode = new ToolNode(tools);
return toolNode.invoke(state);
})
.addNode(VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE, validateEsqlInLastMessageNode)
.addNode(BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE, buildSuccessReportFromLastMessageNode)
.addNode(BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE, buildErrorReportFromLastMessageNode)
.addNode(NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE, nlToEsqlAgentWithoutValidationNode)
.addNode(
BUILD_UNVALIDATED_REPORT_FROM_LAST_MESSAGE_NODE,
buildUnvalidatedReportFromLastMessageNode
)
// Edges
.addEdge(START, SELECT_INDEX_PATTERN_GRAPH)
.addConditionalEdges(SELECT_INDEX_PATTERN_GRAPH, selectIndexStepRouter, {
[NL_TO_ESQL_AGENT_NODE]: NL_TO_ESQL_AGENT_NODE,
[END]: END,
})
.addConditionalEdges(NL_TO_ESQL_AGENT_NODE, nlToEsqlAgentStepRouter, {
[VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE]: VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE,
[TOOLS_NODE]: TOOLS_NODE,
})
.addConditionalEdges(
VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE,
validateEsqlFromLastMessageStepRouter,
{
[BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE]: BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE,
[BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE]: BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE,
[NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE]: NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE,
}
)
.addEdge(TOOLS_NODE, NL_TO_ESQL_AGENT_NODE)
.addEdge(BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE, NL_TO_ESQL_AGENT_NODE)
.addEdge(BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE, END)
.addEdge(
NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE,
BUILD_UNVALIDATED_REPORT_FROM_LAST_MESSAGE_NODE
)
.addEdge(BUILD_UNVALIDATED_REPORT_FROM_LAST_MESSAGE_NODE, END)
.compile();
return graph;
};

View file

@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import type { GenerateEsqlAnnotation } from '../../state';
import { lastMessageWithErrorReport } from './utils';
export const getBuildErrorReportFromLastMessageNode = () => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { messages, validateEsqlResults } = state;
const lastMessage = messages[messages.length - 1];
const containsInvalidQueries = validateEsqlResults.some((result) => !result.isValid);
if (!containsInvalidQueries) {
throw new Error('Expected at least one invalid query to be present in the last message');
}
return new Command({
update: {
messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)],
},
});
};
};

View file

@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { EditorError } from '@kbn/esql-ast';
import type { ValidateEsqlResult } from '../validate_esql_in_last_message_node/utils';
import { lastMessageWithErrorReport } from './utils';
describe('esql self healing validator utils', () => {
it('with errors', () => {
const query = 'FROM .logs\n| LIMIT 100';
const message = `Here is the ESQL query to fetch 100 documents from the .logs index
\`\`\`esql
${query}
\`\`\``;
const validateEsqlResult: ValidateEsqlResult = {
query,
isValid: false,
parsingErrors: [
{ message: 'Syntax error', startLineNumber: 1, startColumn: 10 } as EditorError,
],
executionError: new Error('Unknown index .logs'),
};
const result = lastMessageWithErrorReport(message, [validateEsqlResult]);
expect(result.content).toContain(message);
expect(result.content).toContain(
'The above query has the following errors that still need to be fixed'
);
expect(result.content)
.toEqual(`Here is the ESQL query to fetch 100 documents from the .logs index
\`\`\`esql
FROM .logs
| LIMIT 100
\`\`\`
The above query has the following errors that still need to be fixed:
1:10 Syntax error
Unknown index .logs
`);
});
it('without errors', () => {
const query = 'FROM .logs\n| LIMIT 100';
const message = `Here is the ESQL query to fetch 100 documents from the .logs index
\`\`\`esql
${query}
\`\`\``;
const validateEsqlResult: ValidateEsqlResult = {
query,
isValid: true,
};
const result = lastMessageWithErrorReport(message, [validateEsqlResult]);
expect(result.content).toContain(message);
expect(result.content).toContain('Query is valid');
});
});

View file

@ -0,0 +1,64 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessage } from '@langchain/core/messages';
import { HumanMessage } from '@langchain/core/messages';
import type { ValidateEsqlResult } from '../validate_esql_in_last_message_node/utils';
/**
* Returns the last message with the error report for each query.
*/
export const lastMessageWithErrorReport = (
message: string,
validateEsqlResults: ValidateEsqlResult[]
): BaseMessage => {
let messageWithErrorReport = message;
validateEsqlResults.reverse().forEach((validateEsqlResult) => {
const index = messageWithErrorReport.indexOf(
'```',
messageWithErrorReport.indexOf(validateEsqlResult.query)
);
const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult);
messageWithErrorReport = `${messageWithErrorReport.slice(
0,
index + 3
)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`;
});
return new HumanMessage({
content: messageWithErrorReport,
});
};
const formatValidateEsqlResultToHumanReadable = (validateEsqlResult: ValidateEsqlResult) => {
if (validateEsqlResult.isValid) {
return 'Query is valid';
}
let errorMessage = 'The above query has the following errors that still need to be fixed:\n';
if (validateEsqlResult.parsingErrors) {
errorMessage += `${validateEsqlResult.parsingErrors
.map((error) => `${error.startLineNumber}:${error.startColumn} ${error.message}`)
.join('\n')}\n`;
}
if (validateEsqlResult.executionError) {
errorMessage += `${extractErrorMessage(validateEsqlResult.executionError)}\n`;
}
return errorMessage;
};
const extractErrorMessage = (error: unknown): string => {
if (
error &&
typeof error === 'object' &&
'message' in error &&
typeof error.message === 'string'
) {
return error.message;
}
return `Unknown error`;
};

View file

@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import type { GenerateEsqlAnnotation } from '../../state';
export const getBuildSuccessReportFromLastMessageNode = () => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { messages, validateEsqlResults } = state;
const lastMessage = messages[messages.length - 1];
const containsInvalidQueries = validateEsqlResults.some((result) => !result.isValid);
if (containsInvalidQueries) {
throw new Error('Expected all queries to be valid.');
}
return new Command({
update: {
messages: [`${lastMessage.content}\n\nAll queries have been validated.`],
},
});
};
};

View file

@ -0,0 +1,27 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import type { GenerateEsqlAnnotation } from '../../state';
import { lastMessageWithUnvalidatedReport } from './utils';
export const getBuildUnvalidatedReportFromLastMessageNode = () => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
return new Command({
update: {
messages: [
`${
lastMessageWithUnvalidatedReport(lastMessage.content as string).content
}\n\n The resulting query was generated as a best effort example, but we are unable to validate it. Please provide the name of the index and fields that should be used in the query. Make sure to include this in the final response`,
],
},
});
};
};

View file

@ -0,0 +1,40 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { lastMessageWithUnvalidatedReport } from './utils';
describe('utils', () => {
it('lastMessageWithUnvalidatedReport', () => {
const response = `
Here is the ES|QL query to retrieve 10 items from the ".data.hello" index:
\`\`\`esql
FROM ".data.hello"
| LIMIT 10
\`\`\`
This query limits the results to 10 items from the specified index. Let me know if you need further assistance!
`;
const result = lastMessageWithUnvalidatedReport(response);
expect(result.content).toEqual(
`
Here is the ES|QL query to retrieve 10 items from the ".data.hello" index:
\`\`\`esql
FROM ".data.hello"
| LIMIT 10
// This query was not validated.
\`\`\`
This query limits the results to 10 items from the specified index. Let me know if you need further assistance!
`
);
});
});

View file

@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessage } from '@langchain/core/messages';
import { HumanMessage } from '@langchain/core/messages';
export const lastMessageWithUnvalidatedReport = (lastMessage: string): BaseMessage => {
let result = '';
let startIndex = 0;
while (true) {
const start = lastMessage.indexOf('```esql', startIndex);
if (start === -1) break;
const end = lastMessage.indexOf('```', start + 7);
if (end === -1) break;
result += `${lastMessage.substring(startIndex, end)}\n// This query was not validated.\n\`\`\``;
startIndex = end + 3;
}
result += lastMessage.substring(startIndex);
return new HumanMessage({
content: result,
});
};

View file

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getNlToEsqlAgent } from './nl_to_esql_agent';
import type { KibanaRequest } from '@kbn/core/server';
import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { loggerMock } from '@kbn/logging-mocks';
import { AIMessage, HumanMessage, ToolMessage } from '@langchain/core/messages';
import { Observable } from 'rxjs';
import type { ChatCompletionMessageEvent } from '@kbn/inference-common';
import { ChatCompletionEventType } from '@kbn/inference-common';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import type { GenerateEsqlAnnotation } from '../../state';
jest.mock('@kbn/inference-plugin/server', () => ({
naturalLanguageToEsql: jest.fn(),
}));
describe('nl to esql agent', () => {
const request = {
body: {
isEnabledKnowledgeBase: false,
alertsIndexPattern: '.alerts-security.alerts-default',
allow: ['@timestamp', 'cloud.availability_zone', 'user.name'],
allowReplacement: ['user.name'],
replacements: { key: 'value' },
size: 20,
},
} as unknown as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
const logger = loggerMock.create();
const inference = {
getClient: jest.fn(),
} as unknown as InferenceServerStart;
const connectorId = 'fake-connector';
const rest = {
logger,
request,
inference,
connectorId,
tools: [],
};
const sampleState: typeof GenerateEsqlAnnotation.State = {
input: {
question: 'test',
},
messages: [],
validateEsqlResults: [],
maximumValidationAttempts: 3,
maximumEsqlGenerationAttempts: 3,
selectedIndexPattern: '',
};
const sampleMessageLog = [
new HumanMessage({
content: 'Human message hello',
}),
new AIMessage({
content: 'AI message hello',
tool_calls: [
{
name: 'sampleTool',
id: '123',
args: {},
},
],
}),
new ToolMessage({
tool_call_id: '123',
content: 'Tool message hello',
}),
];
it('calls naturalLanguageToEsql with the correct parameters', async () => {
const agent = getNlToEsqlAgent({ ...rest });
(naturalLanguageToEsql as unknown as jest.Mock).mockReturnValue(
new Observable((subscriber) => {
const result: ChatCompletionMessageEvent = {
content: 'Hello, World!',
type: ChatCompletionEventType.ChatCompletionMessage,
toolCalls: [],
};
subscriber.next(result);
subscriber.complete();
})
);
const result = await agent({
...sampleState,
messages: sampleMessageLog,
});
expect(naturalLanguageToEsql).toHaveBeenCalledTimes(1);
expect(naturalLanguageToEsql).toHaveBeenCalledWith(
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
content: 'Human message hello',
}),
expect.objectContaining({
content: 'AI message hello',
}),
expect.objectContaining({
name: '123',
response: {
response: 'Tool message hello',
},
}),
]),
})
);
expect((result.update as { messages: unknown[] })?.messages[0]).toBeInstanceOf(AIMessage);
});
});

View file

@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { lastValueFrom } from 'rxjs';
import type { StructuredToolInterface } from '@langchain/core/tools';
import type { KibanaRequest, Logger } from '@kbn/core/server';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import type { ChatCompletionMessageEvent } from '@kbn/inference-common';
import { Command } from '@langchain/langgraph';
import { responseToLangchainMessage } from '@kbn/inference-langchain/src/chat_model/from_inference';
import {
messagesToInference,
toolDefinitionToInference,
} from '@kbn/inference-langchain/src/chat_model/to_inference';
import type { GenerateEsqlAnnotation } from '../../state';
export const getNlToEsqlAgent = ({
connectorId,
inference,
logger,
request,
tools,
}: {
connectorId: string;
inference: InferenceServerStart;
logger: Logger;
request: KibanaRequest;
tools: StructuredToolInterface[];
}) => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { messages: stateMessages } = state;
const inferenceMessages = messagesToInference(stateMessages);
const result = (await lastValueFrom(
naturalLanguageToEsql({
client: inference.getClient({ request }),
connectorId,
functionCalling: 'auto',
logger,
tools: toolDefinitionToInference(tools),
messages: inferenceMessages.messages,
system: "Just produce the query fenced by the esql tag. Don't explain it.",
})
)) as ChatCompletionMessageEvent;
// const aiMessage = requireFirstInspectIndexMappingCallWithEmptyKey(responseToLangchainMessage(result), stateMessages);
return new Command({
update: {
maximumEsqlGenerationAttempts: state.maximumEsqlGenerationAttempts - 1,
messages: [responseToLangchainMessage(result)],
},
});
};
};

View file

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { lastValueFrom } from 'rxjs';
import type { KibanaRequest, Logger } from '@kbn/core/server';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import type { ChatCompletionMessageEvent } from '@kbn/inference-common';
import { Command } from '@langchain/langgraph';
import { responseToLangchainMessage } from '@kbn/inference-langchain/src/chat_model/from_inference';
import type { GenerateEsqlAnnotation } from '../../state';
export const getNlToEsqlAgentWithoutValidation = ({
connectorId,
inference,
logger,
request,
}: {
connectorId: string;
inference: InferenceServerStart;
logger: Logger;
request: KibanaRequest;
}) => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { input } = state;
if (!input) {
throw new Error('Input is required');
}
const result = (await lastValueFrom(
naturalLanguageToEsql({
client: inference.getClient({ request }),
connectorId,
logger,
input: input.question,
system: "Just produce the query fenced by the esql tag. Don't explain it.",
})
)) as ChatCompletionMessageEvent;
return new Command({
update: {
messages: [responseToLangchainMessage(result)],
},
});
};
};

View file

@ -0,0 +1,48 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import { HumanMessage } from '@langchain/core/messages';
import type { getSelectIndexPatternGraph } from '../../../select_index_pattern/select_index_pattern';
import type { GenerateEsqlAnnotation } from '../../state';
import { NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE } from '../../constants';
export const getSelectIndexPattern = ({
identifyIndexGraph,
}: {
identifyIndexGraph: ReturnType<typeof getSelectIndexPatternGraph>;
}) => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const childGraphOutput = await identifyIndexGraph.invoke({
input: state.input,
});
if (!childGraphOutput.selectedIndexPattern) {
return new Command({
goto: NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE,
});
}
const context =
childGraphOutput.selectedIndexPattern in childGraphOutput.indexPatternAnalysis
? childGraphOutput.indexPatternAnalysis[childGraphOutput.selectedIndexPattern].context
: undefined;
return new Command({
update: {
selectedIndexPattern: childGraphOutput.selectedIndexPattern,
messages: [
new HumanMessage({
content:
`We have analyzed multiple index patterns to see if they contain the data required for the query. The following index pattern should be used for the query verbatim: '${childGraphOutput.selectedIndexPattern}'.\n` +
`Some context about the index mapping:\n\n${context ? context : ''}`,
}),
],
},
});
};
};

View file

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { extractEsqlFromContent } from './utils';
describe('common', () => {
it.each([
['```esqlhelloworld```', ['helloworld']],
['```esqlhelloworld``````esqlhelloworld```', ['helloworld', 'helloworld']],
['```esql\nFROM sample_data```', ['FROM sample_data']],
['```esql\nFROM sample_data\n```', ['FROM sample_data']],
['```esql\nFROM sample_data\n| LIMIT 3\n```', ['FROM sample_data\n| LIMIT 3']],
])('should add %s and %s', (input: string, expectedResult: string[]) => {
expect(extractEsqlFromContent(input)).toEqual(expectedResult);
});
});

View file

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ElasticsearchClient } from '@kbn/core/server';
import { parse, type EditorError } from '@kbn/esql-ast';
import { isEmpty } from 'lodash';
export interface ValidateEsqlResult {
isValid: boolean;
query: string;
parsingErrors?: EditorError[];
executionError?: unknown;
}
export const validateEsql = async (
esClient: ElasticsearchClient,
query: string
): Promise<ValidateEsqlResult> => {
const { errors: parsingErrors } = parse(query);
if (!isEmpty(parsingErrors)) {
return {
isValid: false,
query,
parsingErrors,
};
}
try {
await esClient.esql.query({
query: `${query}\n| LIMIT 0`, // Add a LIMIT 0 to minimize the risk of executing a costly query
format: 'json',
});
} catch (executionError) {
return {
isValid: false,
query,
executionError,
};
}
return {
isValid: true,
query,
};
};
export const extractEsqlFromContent = (content: string): string[] => {
const extractedEsql = [];
let index = 0;
while (index < content.length) {
const start = content.indexOf('```esql', index);
if (start === -1) {
break;
}
const end = content.indexOf('```', start + 7);
if (end === -1) {
break;
}
extractedEsql.push(content.slice(start + 7, end).trim());
index = end + 3;
}
return extractedEsql;
};

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ElasticsearchClient } from '@kbn/core/server';
import { Command } from '@langchain/langgraph';
import { extractEsqlFromContent, validateEsql } from './utils';
import type { GenerateEsqlAnnotation } from '../../state';
export const getValidateEsqlInLastMessageNode = ({
esClient,
}: {
esClient: ElasticsearchClient;
}) => {
return async (state: typeof GenerateEsqlAnnotation.State) => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
const generatedQueries = extractEsqlFromContent(lastMessage.content as string);
const validateEsqlResults = await Promise.all(
generatedQueries.map((query) => validateEsql(esClient, query))
);
return new Command({
update: {
maximumValidationAttempts: state.maximumValidationAttempts - 1,
validateEsqlResults,
},
});
};
};

View file

@ -0,0 +1,37 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessage } from '@langchain/core/messages';
import { Annotation, messagesStateReducer } from '@langchain/langgraph';
import type { ValidateEsqlResult } from './nodes/validate_esql_in_last_message_node/utils';
export const GenerateEsqlAnnotation = Annotation.Root({
input: Annotation<{ question: string; indexPattern?: string } | undefined>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => undefined,
}),
messages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
validateEsqlResults: Annotation<ValidateEsqlResult[]>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => [],
}),
maximumValidationAttempts: Annotation<number>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => 4,
}),
maximumEsqlGenerationAttempts: Annotation<number>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => 4,
}),
selectedIndexPattern: Annotation<string | undefined | null>({
reducer: (currentValue, newValue) => (newValue === undefined ? currentValue : newValue),
default: () => undefined,
}),
});

View file

@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { END } from '@langchain/langgraph';
import type { GenerateEsqlAnnotation } from './state';
import {
BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE,
BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE,
NL_TO_ESQL_AGENT_NODE,
NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE,
TOOLS_NODE,
VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE,
} from './constants';
import { messageContainsToolCalls } from '../../utils/common';
export const validateEsqlFromLastMessageStepRouter = (
state: typeof GenerateEsqlAnnotation.State
): string => {
const { validateEsqlResults, maximumEsqlGenerationAttempts, maximumValidationAttempts } = state;
const containsInvalidQueries = validateEsqlResults.some(
(validateEsqlResult) => !validateEsqlResult.isValid
);
if (validateEsqlResults.length > 0 && !containsInvalidQueries) {
return BUILD_SUCCESS_REPORT_FROM_LAST_MESSAGE_NODE;
}
if (
validateEsqlResults.length > 0 &&
containsInvalidQueries &&
maximumValidationAttempts > 0 &&
maximumEsqlGenerationAttempts > 0
) {
return BUILD_ERROR_REPORT_FROM_LAST_MESSAGE_NODE;
}
return NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE;
};
export const selectIndexStepRouter = (state: typeof GenerateEsqlAnnotation.State): string => {
const { selectedIndexPattern } = state;
if (selectedIndexPattern == null) {
return END;
}
return NL_TO_ESQL_AGENT_NODE;
};
export const nlToEsqlAgentStepRouter = (state: typeof GenerateEsqlAnnotation.State): string => {
const { messages } = state;
const lastMessage = messages[messages.length - 1];
if (messageContainsToolCalls(lastMessage)) {
return TOOLS_NODE;
}
return VALIDATE_ESQL_FROM_LAST_MESSAGE_NODE;
};

View file

@ -0,0 +1,11 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const GET_INDEX_PATTERNS = 'getIndexPatterns';
export const SHORTLIST_INDEX_PATTERNS = 'shortlistIndexPatterns';
export const ANALYZE_INDEX_PATTERN = 'analyseIndexPattern';
export const SELECT_INDEX_PATTERN = 'selectIndexPattern';

View file

@ -0,0 +1,45 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import type { getAnalyzeIndexPatternGraph } from '../../../analyse_index_pattern/analyse_index_pattern';
export const getAnalyzeIndexPattern = ({
analyzeIndexPatternGraph,
}: {
analyzeIndexPatternGraph: ReturnType<typeof getAnalyzeIndexPatternGraph>;
}) => {
return async ({
input,
}: {
input: {
question: string;
indexPattern: string;
};
}) => {
const result = await analyzeIndexPatternGraph.invoke({
input,
});
const { output } = result;
if (output === undefined) {
throw new Error('No output from analyze index pattern graph');
}
return new Command({
update: {
indexPatternAnalysis: {
[input.indexPattern]: {
indexPattern: input.indexPattern,
containsRequiredData: output.containsRequiredFieldsForQuery,
context: output.context,
},
},
},
});
};
};

View file

@ -0,0 +1,76 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Command } from '@langchain/langgraph';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { SelectIndexPatternAnnotation } from '../../state';
import { buildTree, getIndexPatterns } from './utils';
export const fetchIndexPatterns = ({ esClient }: { esClient: ElasticsearchClient }) => {
return async (state: typeof SelectIndexPatternAnnotation.State) => {
const indicesResolveIndexResponse = await esClient.indices.resolveIndex({
name: '*',
expand_wildcards: 'open',
});
// Stores indices that do not have any datastreams or aliases
const indicesWithoutDatastreamsOrAliases = new Set<string>();
const seenIndices = new Set<string>();
const dataStreamsAndAliases = new Set<string>();
for (const dataStream of indicesResolveIndexResponse.data_streams) {
for (const index of dataStream.backing_indices) {
seenIndices.add(index);
}
dataStreamsAndAliases.add(dataStream.name);
}
for (const alias of indicesResolveIndexResponse.aliases) {
for (const index of alias.indices) {
seenIndices.add(index);
}
dataStreamsAndAliases.add(alias.name);
}
// Add indices that do not have any datastreams or aliases
for (const index of indicesResolveIndexResponse.indices) {
if (!seenIndices.has(index.name)) {
indicesWithoutDatastreamsOrAliases.add(index.name);
}
}
const indexNamePartRootNode = buildTree([
...indicesWithoutDatastreamsOrAliases,
...dataStreamsAndAliases,
]);
const constructedIndexPatterns = getIndexPatterns(indexNamePartRootNode, {
ignoreDigitParts: true,
});
const indexPatterns = new Set<string>();
// Add any index patterns that could be constructed from the indices
for (const indexPattern of constructedIndexPatterns.indexPatterns) {
indexPatterns.add(indexPattern);
}
// Add any remaining indices that did not match any patterns
for (const remainingIndex of constructedIndexPatterns.remainingIndices) {
indexPatterns.add(remainingIndex);
}
const availableIndexPatterns = Array.from(indexPatterns).filter(
(indexPattern) => !indexPattern.startsWith('.')
);
return new Command({
update: {
indexPatterns: availableIndexPatterns,
},
});
};
};

View file

@ -0,0 +1,67 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { buildTree, getIndexPatterns } from './utils';
const indices = [
'employees-development.evaluations.2026.01.13',
'employees-production.evaluations.2006.10.01',
'employees-staging.evaluations.2040.09.12',
'logs-development.evaluations.2005.07.06',
'logs-production.evaluations.2002.11.03',
'logs-staging.evaluations.2014.01.15',
'metricbeat-development.evaluations-2036.11.27',
'metricbeat-production.evaluations-2025.03.15',
'metricbeat-staging.evaluations-2001.03.23',
'metrics-apm-development.evaluations.2007.03.23',
'metrics-apm-development.evaluations.2047.08.12',
'metrics-apm-production.evaluations.2009.02.15',
'metrics-apm-production.evaluations.2025.09.27',
'metrics-apm-staging.evaluations.2041.12.06',
'metrics-apm-staging.evaluations.2043.10.20',
'metrics-endpoint.metadata_current_default',
'nyc_taxis-development.evaluations.2020.09.10',
'nyc_taxis-production.evaluations.2007.02.06',
'nyc_taxis-staging.evaluations.2002.08.21',
'packetbeat-development.evaluations.2040.06.19',
'packetbeat-production.evaluations.2047.07.24',
'packetbeat-staging.evaluations.2043.06.26',
'postgres-logs-development.evaluations.2035.11.24',
'postgres-logs-production.evaluations.2050.11.01',
'postgres-logs-staging.evaluations.2034.07.14',
'traces-apm-development.evlauations.2006.01.28',
'traces-apm-development.evlauations.2006.09.09',
'traces-apm-production.evlauations.2016.12.18',
'traces-apm-production.evlauations.2037.08.13',
'traces-apm-staging.evlauations.2028.11.05',
'traces-apm-staging.evlauations.2029.06.14',
'traches-aapm-staging.evlauations.2029.06.14',
];
describe('convertIndicesToIndexPatterns', () => {
it('should convert indices to index patterns', async () => {
const tree = await buildTree(indices);
const result = getIndexPatterns(tree, { ignoreDigitParts: true });
expect(result.indexPatterns).toEqual([
'employees-*',
'logs-*',
'metricbeat-*',
'metrics-*',
'metrics-apm-*',
'metrics-apm-development.evaluations.*',
'metrics-apm-production.evaluations.*',
'metrics-apm-staging.evaluations.*',
'nyc_taxis-*',
'packetbeat-*',
'postgres-logs-*',
'traces-apm-*',
'traces-apm-production.evlauations.*',
'traces-apm-staging.evlauations.*',
]);
expect(result.remainingIndices).toContain('traches-aapm-staging.evlauations.2029.06.14');
});
});

View file

@ -0,0 +1,127 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
* Splits string by '.' and '-' delimiters, including the delimiters in the result.
* For example, "foo.bar-baz" will be split into ["foo.", "bar-", "baz"].
*/
function splitWithDelimiters(str: string): string[] {
const result = [];
let start = 0;
for (let i = 0; i < str.length; i++) {
if (str[i] === '.' || str[i] === '-') {
result.push(str.slice(start, i + 1)); // Include the delimiter
start = i + 1;
}
}
if (start < str.length) {
result.push(str.slice(start)); // Add last part if there's anything left
}
return result;
}
interface IndexNamePartNode {
value: string;
children?: IndexNamePartNode[];
}
/**
* Builds a tree structure from an array of strings.
*/
export const buildTree = (indices: string[]): IndexNamePartNode => {
const splitIndices = indices.map((index) => splitWithDelimiters(index));
const root: IndexNamePartNode = { value: '', children: [] };
for (const splitIndex of splitIndices) {
let currentNode: IndexNamePartNode = root;
for (const part of splitIndex) {
let childNode = currentNode.children?.find((child) => child.value === part);
if (!childNode) {
childNode = { value: part, children: [] };
currentNode.children?.push(childNode);
}
currentNode = childNode;
}
}
return root;
};
interface Options {
ignoreDigitParts?: boolean;
}
/**
* Generates index patterns from a tree structure.
* @param tree The root IndexNamePartNode of the tree structure.
* @param options Options to customize the behavior of the function.
* @returns
*/
export const getIndexPatterns = (
tree: IndexNamePartNode,
options?: Options
): {
indexPatterns: string[];
remainingIndices: string[];
} => {
// This function will traverse the tree and generate index patterns.
const stack: Array<{ node: IndexNamePartNode; prefix: string; indexPatternAdded: boolean }> = [
{ node: tree, prefix: '', indexPatternAdded: false },
];
const indexPatterns: Set<string> = new Set();
const remainingIndices: Set<string> = new Set();
const ignoreDigitParts = Boolean(options?.ignoreDigitParts);
while (stack.length > 0) {
let indexPatternAdded = false;
const next = stack.pop();
if (!next) {
break;
}
const { node, prefix, indexPatternAdded: parentIndexPatternAdded } = next;
if (
node.children &&
node.children.length > 1 &&
node.value !== '' &&
(ignoreDigitParts ? !isNumber(node.value.replace('.', '').replace('-', '')) : true)
) {
// If there are multiple children, we can create a wildcard pattern
indexPatterns.add(`${prefix}${node.value}*`);
indexPatternAdded = true;
}
if (node.children && node.children.length > 0) {
for (const child of node.children) {
stack.push({
node: child,
prefix: `${prefix}${node.value}`,
indexPatternAdded: parentIndexPatternAdded || indexPatternAdded,
});
}
} else {
// If there are no children, we can create a specific index pattern
if (!(parentIndexPatternAdded || indexPatternAdded)) {
remainingIndices.add(`${prefix}${node.value}`);
}
}
}
return {
indexPatterns: [...indexPatterns].filter((pattern) => pattern.includes('*')).sort(),
remainingIndices: [...remainingIndices].sort(),
};
};
function isNumber(str: string): boolean {
return !isNaN(Number(str)) && !isNaN(parseFloat(str));
}

View file

@ -0,0 +1,59 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import { Command } from '@langchain/langgraph';
import type { SelectIndexPatternAnnotation } from '../../state';
export const getSelectIndexPattern = ({
createLlmInstance,
}: {
createLlmInstance: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
}) => {
return async (state: typeof SelectIndexPatternAnnotation.State) => {
const indexPatternAnalysis = Object.values(state.indexPatternAnalysis);
const candidateIndexPatterns = indexPatternAnalysis.filter(
({ containsRequiredData }) => containsRequiredData
);
if (candidateIndexPatterns.length === 0) {
// Non of the analyzed index patterns contained the required data
return new Command({
update: {
selectedIndexPattern: null,
},
});
}
if (candidateIndexPatterns.length === 1) {
// Exactly one index pattern contains the required data
// We can skip the LLM and return the index pattern directly
return new Command({
update: {
selectedIndexPattern: candidateIndexPatterns[0].indexPattern,
},
});
}
// More than one index pattern contains the required data, we will pick the shortest one (this is likely to be the least specific)
return new Command({
update: {
selectedIndexPattern: candidateIndexPatterns.sort(
(a, b) => a.indexPattern.length - b.indexPattern.length
)[0].indexPattern,
},
});
};
};

View file

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import { HumanMessage, SystemMessage } from '@langchain/core/messages';
import { Command } from '@langchain/langgraph';
import { z } from '@kbn/zod';
import type { SelectIndexPatternAnnotation } from '../../state';
const ShortlistedIndexPatterns = z
.object({
shortlistedIndexPatterns: z.array(z.string()).describe('Shortlisted index patterns'),
})
.describe(
'Object containing array of shortlisted index patterns that might be used to generate the query'
);
export const getShortlistIndexPatterns = ({
createLlmInstance,
}: {
createLlmInstance: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
}) => {
const llm = createLlmInstance();
return async (state: typeof SelectIndexPatternAnnotation.State) => {
const systemMessage = new SystemMessage({
content: `You are a security analyst who is an expert in Elasticsearch and particularly writing Elastic Search queries. You have been given a list of index patterns and an explanation of the query we would like to generate.
To generate the query we first need to identify which index pattern should be used. To do this you short list a maximum of 3 index patterns that are the most likely to contain the fields required to write the query. Select a variety index patterns.`,
});
const humanMessage = new HumanMessage({
content: `Available index patterns:\n ${state.indexPatterns.join(
'\n'
)} \n\n Explanation of the query: \n\n ${
state.input?.question
} \n\n Based on this information, please shortlist a maximum of 3 index patterns that are the most likely to contain the fields required to write the query.`,
});
try {
const result = await llm
.withStructuredOutput(ShortlistedIndexPatterns, { name: 'shortlistedIndexPatterns' })
.withRetry({
stopAfterAttempt: 3,
})
.invoke([systemMessage, humanMessage]);
return new Command({
update: {
shortlistedIndexPatterns: result.shortlistedIndexPatterns,
},
});
} catch (error) {
return new Command({});
}
};
};

View file

@ -0,0 +1,90 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { START, StateGraph, Send, END } from '@langchain/langgraph';
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import { SelectIndexPatternAnnotation } from './state';
import {
ANALYZE_INDEX_PATTERN,
GET_INDEX_PATTERNS,
SELECT_INDEX_PATTERN,
SHORTLIST_INDEX_PATTERNS,
} from './constants';
import { fetchIndexPatterns } from './nodes/fetch_index_patterns/fetch_index_patterns';
import { getShortlistIndexPatterns } from './nodes/shortlist_index_patterns/shortlist_index_patterns';
import { getAnalyzeIndexPattern } from './nodes/analyse_index_pattern/analyse_index_pattern';
import { getSelectIndexPattern } from './nodes/select_index/select_index';
import { getAnalyzeIndexPatternGraph } from '../analyse_index_pattern/analyse_index_pattern';
export const getSelectIndexPatternGraph = ({
createLlmInstance,
esClient,
}: {
createLlmInstance: () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
esClient: ElasticsearchClient;
}) => {
const analyzeIndexPatternGraph = getAnalyzeIndexPatternGraph({
esClient,
createLlmInstance,
});
const graph = new StateGraph(SelectIndexPatternAnnotation)
.addNode(GET_INDEX_PATTERNS, fetchIndexPatterns({ esClient }), {
retryPolicy: { maxAttempts: 3 },
})
.addNode(SHORTLIST_INDEX_PATTERNS, getShortlistIndexPatterns({ createLlmInstance }))
.addNode(
ANALYZE_INDEX_PATTERN,
getAnalyzeIndexPattern({
analyzeIndexPatternGraph,
}),
{ retryPolicy: { maxAttempts: 3 }, subgraphs: [analyzeIndexPatternGraph] }
)
.addNode(SELECT_INDEX_PATTERN, getSelectIndexPattern({ createLlmInstance }), {
retryPolicy: { maxAttempts: 3 },
})
.addEdge(START, GET_INDEX_PATTERNS)
.addEdge(GET_INDEX_PATTERNS, SHORTLIST_INDEX_PATTERNS)
.addConditionalEdges(
SHORTLIST_INDEX_PATTERNS,
(state: typeof SelectIndexPatternAnnotation.State) => {
const { input } = state;
if (input === undefined) {
throw new Error('State input is undefined');
}
if (state.shortlistedIndexPatterns.length === 0) {
return END;
}
return state.shortlistedIndexPatterns.map((indexPattern) => {
return new Send(ANALYZE_INDEX_PATTERN, {
input: {
question: input.question,
indexPattern,
},
});
});
},
{
[ANALYZE_INDEX_PATTERN]: ANALYZE_INDEX_PATTERN,
[END]: END,
}
)
.addEdge(ANALYZE_INDEX_PATTERN, SELECT_INDEX_PATTERN)
.addEdge(SELECT_INDEX_PATTERN, END)
.compile();
return graph;
};

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessage } from '@langchain/core/messages';
import { Annotation, messagesStateReducer } from '@langchain/langgraph';
export const SelectIndexPatternAnnotation = Annotation.Root({
messages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
input: Annotation<{ question: string; indexPattern?: string } | undefined>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => undefined,
}),
indexPatterns: Annotation<string[]>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => [],
}),
shortlistedIndexPatterns: Annotation<string[]>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => [],
}),
indexPatternAnalysis: Annotation<
Record<string, { containsRequiredData: boolean; indexPattern: string; context: string }>
>({
reducer: (currentValue, newValue) => ({ ...currentValue, ...newValue }),
default: () => ({}),
}),
selectedIndexPattern: Annotation<string | undefined | null>({
reducer: (currentValue, newValue) => (newValue === undefined ? currentValue : newValue),
default: () => undefined,
}),
});

View file

@ -1,108 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { RetrievalQAChain } from 'langchain/chains';
import type { DynamicTool } from '@langchain/core/tools';
import { NL_TO_ESQL_TOOL } from './nl_to_esql_tool';
import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
import { loggerMock } from '@kbn/logging-mocks';
import { getPromptSuffixForOssModel } from './common';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import type { ContentReferencesStore } from '@kbn/elastic-assistant-common';
describe('NaturalLanguageESQLTool', () => {
const chain = {} as RetrievalQAChain;
const esClient = {
search: jest.fn().mockResolvedValue({}),
} as unknown as ElasticsearchClient;
const request = {
body: {
isEnabledKnowledgeBase: false,
alertsIndexPattern: '.alerts-security.alerts-default',
allow: ['@timestamp', 'cloud.availability_zone', 'user.name'],
allowReplacement: ['user.name'],
replacements: { key: 'value' },
size: 20,
},
} as unknown as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
const logger = loggerMock.create();
const inference = {} as InferenceServerStart;
const connectorId = 'fake-connector';
const contentReferencesStore = {} as ContentReferencesStore;
const rest = {
chain,
esClient,
logger,
request,
inference,
connectorId,
isEnabledKnowledgeBase: true,
contentReferencesStore,
};
describe('isSupported', () => {
it('returns true if connectorId and inference have values', () => {
expect(NL_TO_ESQL_TOOL.isSupported(rest)).toBe(true);
});
});
describe('getTool', () => {
it('returns null if inference plugin is not provided', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
...rest,
inference: undefined,
});
expect(tool).toBeNull();
});
it('returns null if connectorId is not provided', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
...rest,
connectorId: undefined,
});
expect(tool).toBeNull();
});
it('should return a Tool instance when given required properties', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
...rest,
});
expect(tool?.name).toEqual('NaturalLanguageESQLTool');
});
it('should return a tool with the expected tags', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
...rest,
}) as DynamicTool;
expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']);
});
it('should return tool with the expected description for OSS model', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isOssModel: true,
...rest,
}) as DynamicTool;
expect(tool.description).toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
});
it('should return tool with the expected description for non-OSS model', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isOssModel: false,
...rest,
}) as DynamicTool;
expect(tool.description).not.toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
});
});
});

View file

@ -10,11 +10,14 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { z } from '@kbn/zod';
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './common';
import { getPromptSuffixForOssModel } from './utils/common';
// select only some properties of AssistantToolParams
export type ESQLToolParams = AssistantToolParams;
export type ESQLToolParams = AssistantToolParams & {
assistantContext: ElasticAssistantApiRequestHandlerContext;
};
const TOOL_NAME = 'NaturalLanguageESQLTool';
@ -35,11 +38,16 @@ const toolDetails = {
export const NL_TO_ESQL_TOOL: AssistantTool = {
...toolDetails,
sourceRegister: APP_UI_ID,
isSupported: (params: ESQLToolParams): params is ESQLToolParams => {
const { inference, connectorId } = params;
return inference != null && connectorId != null;
isSupported: (params: AssistantToolParams): params is ESQLToolParams => {
const { inference, connectorId, assistantContext } = params;
return (
inference != null &&
connectorId != null &&
assistantContext != null &&
!assistantContext.getRegisteredFeatures('securitySolutionUI').advancedEsqlGeneration
);
},
getTool(params: ESQLToolParams) {
getTool(params: AssistantToolParams) {
if (!this.isSupported(params)) return null;
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;

View file

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ElasticsearchClient } from '@kbn/core/server';
import { tool } from '@langchain/core/tools';
export const toolDetails = {
name: 'available_index_names',
description:
'Get the available indices in the elastic search cluster. Use this when there is an unknown index error or you need to get the available indices.',
};
export const getIndexNamesTool = ({ esClient }: { esClient: ElasticsearchClient }) => {
return tool(
async () => {
const indicesResolveIndexResponse = await esClient.indices.resolveIndex({
name: '*',
expand_wildcards: 'open',
});
const resolvedIndexNames = Object.values(indicesResolveIndexResponse)
.flat()
.map((item) => item.name as string)
.sort((a, b) => {
if (a.startsWith('.') && !b.startsWith('.')) return 1;
if (!a.startsWith('.') && b.startsWith('.')) return -1;
return a.localeCompare(b);
});
return `You can use the wildcard character "*" to query multiple indices at once. For example, if you want to query all logs indices that start with "logs-", you can use "logs-*". If the precice index was not specified in the task, it is best to make a more general query using a wildcard. Bellow are the available indecies:
${resolvedIndexNames.join('\n')}`;
},
{
name: toolDetails.name,
description: toolDetails.description,
}
);
};

View file

@ -0,0 +1,75 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ElasticsearchClient } from '@kbn/core/server';
import { tool } from '@langchain/core/tools';
import { z } from '@kbn/zod';
import { IndexPatternsFetcher } from '@kbn/data-views-plugin/server';
import {
getNestedValue,
mapFieldDescriptorToNestedObject,
shallowObjectViewTruncated,
} from './inspect_index_utils';
export const toolDetails = {
name: 'inspect_index_mapping',
description:
'Use this tool when there is a "verification_exception Unknown column" error or to see which fields and types are used in the index.' +
'This function will return as much of the index mapping as possible. If the index mapping is too large, some values will be truncated as indicated by "...".' +
'Call the function again to inspected truncated values.' +
`Example:
Index mapping:
{"user":{"address":{"city":{"name":{"type":"keyword"},"zip":{"type":"integer"}}}}}}
Call #1:
Property: "" // empty string to get the root
Function response: {"user":{"address":{"city":"...","zip":"..."}}}
Call #2:
Property: "user.address.city"
Function response: {"name":{"type":"keyword"}
`,
};
export const getInspectIndexMappingTool = ({
esClient,
indexPattern,
}: {
esClient: ElasticsearchClient;
indexPattern: string;
}) => {
const indexPatternsFetcher = new IndexPatternsFetcher(esClient);
return tool(
async ({ property }) => {
const { fields } = await indexPatternsFetcher.getFieldsForWildcard({
pattern: indexPattern,
fieldCapsOptions: {
allow_no_indices: false,
includeUnmapped: false,
},
});
const prunedFields = fields.map((p) => ({ name: p.name, type: p.esTypes[0] }));
const nestedObject = mapFieldDescriptorToNestedObject(prunedFields);
const nestedValue = getNestedValue(nestedObject, property);
const result = shallowObjectViewTruncated(nestedValue, 30000);
return result ? JSON.stringify(result) : `No value found for property "${property}".`;
},
{
name: toolDetails.name,
description: toolDetails.description,
schema: z.object({
property: z
.string()
.describe(
`The property to inspect. The property should be a dot-separated path to the field in the index mapping. For example, "user.name" or "user.address.city". Empty string will return the root.`
),
}),
}
);
};

View file

@ -0,0 +1,515 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { compressMapping } from '../../graphs/analyse_index_pattern/nodes/analyze_compressed_index_mapping_agent/compress_mapping';
import {
getNestedValue,
mapFieldDescriptorToNestedObject,
shallowObjectView,
shallowObjectViewTruncated,
} from './inspect_index_utils';
const sampleMapping1 = {
mappings: {
properties: {
field1: {
type: 'keyword',
},
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
},
};
describe('inspect index', () => {
it.each([
[
sampleMapping1,
'mappings.properties',
{
field1: {
type: 'keyword',
},
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
],
[
sampleMapping1,
'mappings.properties.field1',
{
type: 'keyword',
},
],
[
{
foo: [{ bar: 1 }, { bar: 2 }],
},
'foo.1.bar',
2,
],
[
{
foo: [{ bar: 1 }, { bar: 2 }],
},
'',
{
foo: [{ bar: 1 }, { bar: 2 }],
},
],
])(
'getEntriesAtKey input %s returns %s',
(mapping: unknown, key: string, expectedResult: unknown) => {
expect(getNestedValue(mapping, key)).toEqual(expectedResult);
}
);
it.each([
[
{
type: 'keyword',
},
1,
{
type: 'keyword',
},
],
[
{
field1: {
type: 'keyword',
},
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
1,
{
field1: '...',
field2: '...',
},
],
[
{
field1: {
type: 'keyword',
},
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
2,
{
field1: {
type: 'keyword',
},
field2: {
properties: '...',
},
},
],
[
{
field1: {
type: 'keyword',
},
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
1,
{
field1: '...',
field2: '...',
},
],
[
{
field1: 'keyword',
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
1,
{
field1: 'keyword',
field2: '...',
},
],
[
{
field1: [1, 2, 3],
field2: {
properties: {
nested_field: {
type: 'keyword',
},
},
},
},
2,
{
field1: [1, 2, 3],
field2: {
properties: '...',
},
},
],
])(
'shallowObjectView input %s returns %s',
(mapping: unknown, maxDepth: number, expectedResult: unknown) => {
expect(shallowObjectView(mapping, maxDepth)).toEqual(expectedResult);
}
);
it('shallowObjectView returns undefined for undefined mapping', () => {
expect(shallowObjectView(undefined)).toEqual(undefined);
});
it('shallowObjectView returns mapping for string mapping', () => {
expect(shallowObjectView('string')).toEqual('string');
});
it('shallowObjectView returns Object for maxDepth 0', () => {
expect(shallowObjectView(sampleMapping1, 0)).toEqual('...');
});
it('shallowObjectViewTruncated returns truncated view', () => {
expect(shallowObjectViewTruncated(sampleMapping1, 10)).toEqual({
mappings: '...',
});
});
it('shallowObjectViewTruncated does not reduce depth if maxCharacters is not exceeded', () => {
expect(shallowObjectViewTruncated(sampleMapping1, 200)).toEqual({
mappings: {
properties: {
field1: {
type: 'keyword',
},
field2: {
properties: '...',
},
},
},
});
});
it('shallowObjectViewTruncated reduces depth if maxCharacters is exceeded', () => {
expect(shallowObjectViewTruncated(sampleMapping1, 50)).toEqual({
mappings: {
properties: '...',
},
});
});
it('fieldDescriptor maps to nested object', () => {
const fieldDescriptors = [
{
name: '@timestamp',
type: 'date',
esTypes: ['date'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
fixedInterval: undefined,
timeZone: undefined,
timeSeriesMetric: undefined,
timeSeriesDimension: undefined,
},
{
name: 'effective_process.entity_id',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
fixedInterval: undefined,
timeZone: undefined,
timeSeriesMetric: undefined,
timeSeriesDimension: undefined,
},
{
name: 'effective_process.executable',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
fixedInterval: undefined,
timeZone: undefined,
timeSeriesMetric: undefined,
timeSeriesDimension: undefined,
},
];
const nestedObject = mapFieldDescriptorToNestedObject(fieldDescriptors);
expect(nestedObject).toEqual({
'@timestamp': {
aggregatable: true,
esTypes: ['date'],
fixedInterval: undefined,
metadata_field: false,
readFromDocValues: true,
searchable: true,
timeSeriesDimension: undefined,
timeSeriesMetric: undefined,
timeZone: undefined,
type: 'date',
},
effective_process: {
entity_id: {
aggregatable: true,
esTypes: ['keyword'],
fixedInterval: undefined,
metadata_field: false,
readFromDocValues: true,
searchable: true,
timeSeriesDimension: undefined,
timeSeriesMetric: undefined,
timeZone: undefined,
type: 'string',
},
executable: {
aggregatable: true,
esTypes: ['keyword'],
fixedInterval: undefined,
metadata_field: false,
readFromDocValues: true,
searchable: true,
timeSeriesDimension: undefined,
timeSeriesMetric: undefined,
timeZone: undefined,
type: 'string',
},
},
});
});
it('1fieldDescriptor maps to nested object', () => {
const fieldDescriptors = [
{
name: 'test',
type: 'number',
esTypes: ['long'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'bar',
type: 'number',
esTypes: ['long'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.Ext.options',
type: 'number',
esTypes: ['long'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.Ext.status',
type: 'number',
esTypes: ['long'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.answers.class',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.answers.data',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.answers.name',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.answers.ttl',
type: 'number',
esTypes: ['long'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.answers.type',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.header_flags',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.id',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.op_code',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.class',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.name',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.registered_domain',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.subdomain',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.top_level_domain',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.question.type',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.resolved_ip',
type: 'ip',
esTypes: ['ip'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
{
name: 'dns.response_code',
type: 'string',
esTypes: ['keyword'],
searchable: true,
aggregatable: true,
readFromDocValues: true,
metadata_field: false,
},
];
const nestedObject = mapFieldDescriptorToNestedObject(
fieldDescriptors.map((p) => ({ name: p.name, type: p.esTypes[0] }))
);
const result = compressMapping(nestedObject);
expect(result).toEqual(
`test,bar:long\ndns:{header_flags,id,op_code,response_code:keyword,resolved_ip:ip,Ext:{options,status:long},answers:{class,data,name,type:keyword,ttl:long},question:{class,name,registered_domain,subdomain,top_level_domain,type:keyword}}`
);
});
});

View file

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { get } from 'lodash';
/**
* Gets the entries at a given key in an index mapping
*/
export const getNestedValue = (obj: unknown, keyPath: string) => {
return keyPath ? get(obj, keyPath) : obj;
};
/**
* Returns a shallow view of the object
* @param obj The object
* @param maxDepth The maximum depth to recurse into the object
* @returns A shallow view of the mapping
*/
export const shallowObjectView = (obj: unknown, maxDepth = 1): object | string | undefined => {
if (
obj === undefined ||
typeof obj === 'string' ||
typeof obj === 'number' ||
typeof obj === 'boolean'
) {
return obj?.toString() ?? undefined;
}
if (Array.isArray(obj)) {
return maxDepth <= 0 ? '...' : obj;
}
if (typeof obj === 'object' && obj !== null) {
if (maxDepth <= 0) {
return '...';
}
return Object.fromEntries(
Object.entries(obj).map(([key, value]) => [
key,
typeof value === 'object'
? shallowObjectView(value, maxDepth - 1)
: value?.toString() ?? undefined,
])
);
}
return 'unknown';
};
/**
* Same as shallowObjectView but reduces the maxDepth if the stringified view is longer than maxCharacters
* @param mapping The index mapping
* @param maxCharacters The maximum number of characters to return
* @param maxDepth The maximum depth to recurse into the object
* @returns A shallow view of the mapping
*/
export const shallowObjectViewTruncated = (
obj: unknown,
maxCharacters: number,
maxDepth = 4
): object | string | undefined => {
const view = shallowObjectView(obj, maxDepth);
if (maxDepth > 1 && view && JSON.stringify(view).length > maxCharacters) {
return shallowObjectViewTruncated(view, maxCharacters, maxDepth - 1);
}
return view;
};
interface TypedProperty {
type: string;
[key: string]: unknown;
}
interface NestedObject {
[key: string]: TypedProperty | NestedObject;
}
export const mapFieldDescriptorToNestedObject = <T extends { name: string; type: string }>(
arr: T[]
): NestedObject => {
return arr.reduce<NestedObject>((acc, obj) => {
const keys = obj.name.split('.');
keys.reduce((nested: NestedObject, key, index) => {
if (!(key in nested)) {
nested[key] =
index === keys.length - 1
? (Object.fromEntries(
Object.entries(obj).filter(([k]) => k !== 'name')
) as TypedProperty)
: {};
}
return nested[key] as NestedObject;
}, acc);
return acc;
}, {});
};

View file

@ -0,0 +1,85 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { BaseMessage } from '@langchain/core/messages';
import { AIMessage } from '@langchain/core/messages';
import { requireFirstInspectIndexMappingCallWithEmptyKey } from './common';
const aiMessage1 = new AIMessage({
content: 'test',
tool_calls: [
{
name: 'inspect_index_mapping',
args: {
property: '',
},
},
],
});
const aiMessage2 = new AIMessage({
content: 'test',
tool_calls: [
{
name: 'inspect_index_mapping',
args: {
property: '',
},
},
{
name: 'inspect_index_mapping',
args: {
property: 'test',
},
},
],
});
const aiMessage3 = new AIMessage({
content: 'test',
tool_calls: [
{
name: 'inspect_index_mapping',
args: {
property: 'test',
},
},
],
});
const aiMessage4 = new AIMessage({
content: 'test',
tool_calls: [
{
name: 'inspect_index_mapping',
args: {
property: 'foo',
},
},
{
name: 'inspect_index_mapping',
args: {
property: 'test',
},
},
],
});
describe('common', () => {
it.each([
[aiMessage3, [], aiMessage1],
[aiMessage2, [], aiMessage2],
[aiMessage4, [], aiMessage2],
])(
'requireFirstInspectIndexMappingCallWithEmptyKey',
(newMessage: AIMessage, oldMessage: BaseMessage[], expected: AIMessage) => {
const result = requireFirstInspectIndexMappingCallWithEmptyKey(newMessage, oldMessage);
expect(result).toEqual(expected);
}
);
});

View file

@ -0,0 +1,87 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type {
ActionsClientChatBedrockConverse,
ActionsClientChatVertexAI,
ActionsClientChatOpenAI,
} from '@kbn/langchain/server';
import type { BaseMessage } from '@langchain/core/messages';
import { AIMessage } from '@langchain/core/messages';
import type { ToolCall } from '@langchain/core/dist/messages/tool';
import { toolDetails } from '../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
export const getPromptSuffixForOssModel = (toolName: string) => `
When using ${toolName} tool ALWAYS pass the user's questions directly as input into the tool.
Always return value from ${toolName} tool as is.
The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks.
It is important that ES|QL query is preceeded by a new line.`;
export const messageContainsToolCalls = (message: BaseMessage): message is AIMessage => {
return (
'tool_calls' in message && Array.isArray(message.tool_calls) && message.tool_calls?.length > 0
);
};
export type CreateLlmInstance = () =>
| ActionsClientChatBedrockConverse
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI;
export const requireFirstInspectIndexMappingCallWithEmptyKey = (
newMessage: AIMessage,
oldMessages: BaseMessage[]
): AIMessage => {
const hasCalledInspectIndexMappingTool = oldMessages.find((message) => {
return (
messageContainsToolCalls(message) &&
message.tool_calls?.some((toolCall) => {
return toolCall.name === toolDetails.name;
})
);
});
if (hasCalledInspectIndexMappingTool) {
return newMessage;
}
const newMessageToolCalls = newMessage.tool_calls || [];
const containsFirstInspectIndexMappingCall = newMessageToolCalls.some((toolCall) => {
return toolCall.name === toolDetails.name;
});
if (!containsFirstInspectIndexMappingCall) {
return newMessage;
}
const modifiedToolCalls: ToolCall[] = [];
let hasModifiedToolCall = false;
for (const toolCall of newMessageToolCalls) {
if (toolCall.name === toolDetails.name && !hasModifiedToolCall) {
modifiedToolCalls.push({
...toolCall,
args: {
...toolCall.args,
property: '',
},
});
hasModifiedToolCall = true;
} else {
modifiedToolCalls.push(toolCall);
}
}
return new AIMessage({
content: newMessage.content,
tool_calls: modifiedToolCalls,
});
};

View file

@ -6,6 +6,8 @@
*/
import { PRODUCT_DOCUMENTATION_TOOL } from './product_docs/product_documentation_tool';
import { GENERATE_ESQL_TOOL } from './esql/generate_esql_tool';
import { ASK_ABOUT_ESQL_TOOL } from './esql/ask_about_esql_tool';
import { NL_TO_ESQL_TOOL } from './esql/nl_to_esql_tool';
import { ALERT_COUNTS_TOOL } from './alert_counts/alert_counts_tool';
import { OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL } from './open_and_acknowledged_alerts/open_and_acknowledged_alerts_tool';
@ -19,6 +21,8 @@ export const assistantTools = [
ALERT_COUNTS_TOOL,
KNOWLEDGE_BASE_RETRIEVAL_TOOL,
KNOWLEDGE_BASE_WRITE_TOOL,
GENERATE_ESQL_TOOL,
ASK_ABOUT_ESQL_TOOL,
NL_TO_ESQL_TOOL,
OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL,
PRODUCT_DOCUMENTATION_TOOL,

View file

@ -594,6 +594,7 @@ export class Plugin implements ISecuritySolutionPlugin {
plugins.elasticAssistant.registerTools(APP_UI_ID, assistantTools);
const features = {
assistantModelEvaluation: config.experimentalFeatures.assistantModelEvaluation,
advancedEsqlGeneration: config.experimentalFeatures.advancedEsqlGeneration,
};
plugins.elasticAssistant.registerFeatures(APP_UI_ID, features);
plugins.elasticAssistant.registerFeatures('management', features);

View file

@ -239,6 +239,9 @@
"@kbn/product-doc-base-plugin",
"@kbn/shared-ux-error-boundary",
"@kbn/security-ai-prompts",
"@kbn/inference-common",
"@kbn/esql-ast",
"@kbn/inference-langchain",
"@kbn/scout-security",
"@kbn/custom-icons",
"@kbn/security-plugin-types-common",