esql self healing graph

This commit is contained in:
Kenneth Kreindler 2025-03-07 12:09:05 +00:00
parent ad3dce1179
commit c76a4469ae
No known key found for this signature in database
GPG key ID: 429CB8689E46A00B
16 changed files with 623 additions and 28 deletions

View file

@ -0,0 +1,48 @@
import { BaseMessage as LangChainBaseMessage, AIMessage as LangChainAIMessage, ToolMessage as LangChainToolMessage } from '@langchain/core/messages';
import { AssistantMessage as InferenceAssistantMessage, Message as InferenceMessage, MessageRole, ToolCall, ToolMessage, UserMessage } from '@kbn/inference-common';
import { isEmpty } from 'lodash';
/**
 * Converts a LangChain message into the equivalent Inference plugin message.
 *
 * - `system`, `generic`, `developer` and `human` messages become user messages.
 * - `ai` messages become assistant messages; tool calls are forwarded only when present.
 * - `tool` messages become tool messages carrying the tool response.
 *
 * @param langChainMessage the LangChain message to convert
 * @returns the Inference message
 * @throws Error when the message type is unsupported, or when the reported type
 *         does not match the concrete message class
 */
export const langchainMessageToInferenceMessage = (langChainMessage: LangChainBaseMessage): InferenceMessage => {
  const messageType = langChainMessage.getType();
  const conversionError = () =>
    new Error(`Unable to convert LangChain message of type ${messageType} to Inference message`);

  // Every "prompt-like" role is mapped onto a user message.
  const userLikeTypes = new Set<string>(["system", "generic", "developer", "human"]);
  if (userLikeTypes.has(messageType)) {
    return {
      role: MessageRole.User,
      content: langChainMessage.content
    } as UserMessage;
  }

  if (messageType === "ai") {
    if (!(langChainMessage instanceof LangChainAIMessage)) {
      throw conversionError();
    }
    const toolCalls: ToolCall[] | undefined = langChainMessage.tool_calls?.map((call): ToolCall => {
      const callArguments = call.args;
      return {
        toolCallId: call.id as string,
        // Arguments are omitted entirely when the call carries none.
        function: isEmpty(callArguments)
          ? { name: call.name }
          : { arguments: callArguments, name: call.name }
      };
    });
    const assistantMessage = {
      role: MessageRole.Assistant,
      content: langChainMessage.content as string,
      ...(isEmpty(toolCalls) ? {} : { toolCalls })
    };
    return assistantMessage as InferenceAssistantMessage;
  }

  if (messageType === "tool") {
    if (!(langChainMessage instanceof LangChainToolMessage)) {
      throw conversionError();
    }
    return {
      name: langChainMessage.name,
      toolCallId: langChainMessage.tool_call_id,
      role: MessageRole.Tool,
      response: langChainMessage.content,
    } as ToolMessage;
  }

  throw conversionError();
}

View file

@ -0,0 +1,24 @@
import { ToolDefinition as InferenceToolDefinition, ToolSchema as InferenceToolSchema } from '@kbn/inference-common';
import { StructuredToolInterface as LangChainStructuredToolInterface } from '@langchain/core/tools';
import { pick } from 'lodash';
import { ZodSchema } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';
/**
 * Converts a single LangChain structured tool into an Inference tool definition.
 * The `schema` property is only included when the tool declares one.
 *
 * @param langchainTool the LangChain tool to convert
 * @returns the Inference tool definition (description plus optional schema)
 */
export const langchainToolToInferenceTool = (langchainTool: LangChainStructuredToolInterface): InferenceToolDefinition => {
  const { description } = langchainTool;
  if (!langchainTool.schema) {
    return { description };
  }
  const schema = zodSchemaToInference(langchainTool.schema);
  return schema ? { description, schema } : { description };
}
/**
 * Converts a list of LangChain tools into a record of Inference tool definitions
 * keyed by tool name. When two tools share a name, the later one wins.
 *
 * @param langchainTool the LangChain tools to convert
 * @returns Inference tool definitions indexed by tool name
 */
export const langchainToolsToInferenceTools = (langchainTool: LangChainStructuredToolInterface[]) : Record<string, InferenceToolDefinition> => {
  const definitionsByName: Record<string, InferenceToolDefinition> = {};
  for (const currentTool of langchainTool) {
    definitionsByName[currentTool.name] = langchainToolToInferenceTool(currentTool);
  }
  return definitionsByName;
}
/**
 * Converts a Zod schema to JSON schema and keeps only the `type`, `properties`
 * and `required` keys of the result.
 */
function zodSchemaToInference(schema: ZodSchema): InferenceToolSchema {
  const jsonSchema = zodToJsonSchema(schema);
  const keptKeys = ['type', 'properties', 'required'];
  return pick(jsonSchema, keptKeys) as InferenceToolSchema;
}

View file

@ -0,0 +1,26 @@
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
import { AIMessage, BaseMessage as LangChainBaseMessage } from '@langchain/core/messages';
import { isEmpty } from "lodash";
/**
 * Converts a chat completion message event into a LangChain `AIMessage`.
 * Tool calls are attached to the message only when the event contains at least one.
 *
 * @param taskEvent the chat completion message event to convert
 * @returns the LangChain AI message
 * @throws Error when the event type is not `chatCompletionMessage`
 */
export const nlToEsqlTaskEventToLangchainMessage = (taskEvent: ChatCompletionMessageEvent): LangChainBaseMessage => {
  switch (taskEvent.type) {
    case "chatCompletionMessage": {
      const langChainToolCalls = taskEvent.toolCalls.map(({ function: fn, toolCallId }) => ({
        name: fn.name,
        args: fn.arguments,
        id: toolCallId,
        type: "tool_call"
      } as const));
      const toolCallFields = isEmpty(langChainToolCalls) ? {} : { tool_calls: langChainToolCalls };
      return new AIMessage({
        content: taskEvent.content,
        ...toolCallFields
      });
    }
    default: {
      throw new Error(`Unable to convert nlToEsqlTaskEvent of type ${taskEvent.type} to LangChain message`)
    }
  }
}

View file

@ -42,6 +42,14 @@ export type {
export { transformRawData } from './impl/data_anonymization/transform_raw_data';
export { parseBedrockBuffer, handleBedrockChunk } from './impl/utils/bedrock';
export { langchainMessageToInferenceMessage } from './impl/utils/langchain_message_to_inference_message';
export {
langchainToolToInferenceTool,
langchainToolsToInferenceTools
} from './impl/utils/langchain_tool_to_inference_tool';
export {
nlToEsqlTaskEventToLangchainMessage
} from './impl/utils/nl_to_esql_task_event_to_langchain_message';
export * from './constants';
/** currently the same shape as "fields" property in the ES response */

View file

@ -0,0 +1,61 @@
import { AssistantMessage, MessageRole, UserMessage, ToolMessage } from "@kbn/inference-common"
import { ensureMultiTurn } from "./ensure_multi_turn"
// Fixtures: one representative message per role.
const assistantTurn: AssistantMessage = {
  role: MessageRole.Assistant,
  content: "Hello from assistant"
}

const userTurn: UserMessage = {
  role: MessageRole.User,
  content: "Hello from user"
}

const toolTurn: ToolMessage = {
  role: MessageRole.Tool,
  response: "Hello from tool",
  toolCallId: "123",
  name: "toolName"
}

// Placeholder messages expected between two consecutive turns of the same role.
const userFiller: UserMessage = {
  role: MessageRole.User,
  content: "-"
}

const assistantFiller: AssistantMessage = {
  role: MessageRole.Assistant,
  content: "-"
}

describe("ensureMultiTurn", () => {
  it("returns correct value for message sequence", () => {
    const input = [
      assistantTurn,
      assistantTurn,
      userTurn,
      userTurn,
      toolTurn,
      toolTurn,
      userTurn,
      assistantTurn,
      toolTurn,
    ]

    // Fillers separate assistant/assistant and user/user pairs;
    // consecutive tool messages are left untouched.
    const expected = [
      assistantTurn,
      userFiller,
      assistantTurn,
      userTurn,
      assistantFiller,
      userTurn,
      toolTurn,
      toolTurn,
      userTurn,
      assistantTurn,
      toolTurn,
    ]

    expect(ensureMultiTurn(input)).toEqual(expected)
  })
})

View file

@ -7,20 +7,65 @@
import { Message, MessageRole } from '@kbn/inference-common';
function isUserMessage(message: Message): boolean {
return message.role !== MessageRole.Assistant;
/**
 * Outcome of checking two adjacent messages.
 * Either the pair is valid, or it is invalid and `intermediaryRole` names the
 * role (user or assistant) of the message that should be placed between them.
 */
type MessageRoleSequenceResult = {
roleSequenceValid: true;
} | {
roleSequenceValid: false;
intermediaryRole: MessageRole.User | MessageRole.Assistant;
};
/**
 * Checks whether `message` may directly follow `prevMessage`.
 *
 * Two consecutive USER messages or two consecutive ASSISTANT messages are not allowed;
 * the result then names the role of the message to insert between them.
 * Consecutive TOOL messages are allowed, as is any pair of differing roles.
 *
 * @param prevMessage the preceding message, or undefined for the first message
 * @param message the message being appended
 */
function checkMessageRoleSequenceValid(prevMessage: Message | undefined, message: Message): MessageRoleSequenceResult {
  // First message, or a change of role: nothing to repair.
  if (!prevMessage || prevMessage.role !== message.role) {
    return { roleSequenceValid: true };
  }

  // From here on both messages share the same role.
  switch (message.role) {
    case MessageRole.User:
      return { roleSequenceValid: false, intermediaryRole: MessageRole.Assistant };
    case MessageRole.Assistant:
      return { roleSequenceValid: false, intermediaryRole: MessageRole.User };
    case MessageRole.Tool:
      return { roleSequenceValid: true };
    default:
      // Any other repeated role is separated by an assistant message.
      return { roleSequenceValid: false, intermediaryRole: MessageRole.Assistant };
  }
}
export function ensureMultiTurn(messages: Message[]): Message[] {
const next: Message[] = [];
messages.forEach((message) => {
const prevMessage = next[next.length - 1];
if (prevMessage && isUserMessage(prevMessage) === isUserMessage(message)) {
const result = checkMessageRoleSequenceValid(prevMessage, message);
if (!result.roleSequenceValid) {
next.push({
content: '-',
role: isUserMessage(message) ? MessageRole.Assistant : MessageRole.User,
role: result.intermediaryRole,
});
}

View file

@ -0,0 +1,3 @@
// String identifiers for the nodes of the ES|QL self-healing graph.
export const NL_TO_ESQL_AGENT_NODE = "nl_to_esql_agent";
export const TOOLS_NODE = "tools";
export const ESQL_VALIDATOR_NODE = "esql_validator";

View file

@ -0,0 +1,68 @@
import { StructuredToolInterface } from "@langchain/core/tools";
import { getIndexNamesTool } from "./index_names_tool";
import { ElasticsearchClient, KibanaRequest } from "@kbn/core/server";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { END, START, StateGraph } from "@langchain/langgraph";
import { EsqlSelfHealingAnnotation } from "./state";
import { ESQL_VALIDATOR_NODE, NL_TO_ESQL_AGENT_NODE, TOOLS_NODE } from "./constants";
import { stepRouter } from "./step_router";
import { getNlToEsqlAgent } from "./nl_to_esql_agent";
import { InferenceServerStart } from "@kbn/inference-plugin/server";
import { Logger } from '@kbn/core/server';
import { getValidatorNode } from "./validator";
import { getInspectIndexMappingTool } from "./inspect_index_mapping_tool";
/**
 * Builds and compiles the ES|QL self-healing graph.
 *
 * Wiring: START -> agent; the agent is routed by `stepRouter` to either the
 * tools node or the validator node; the tools node always returns to the agent;
 * the validator node is declared with END and the agent as possible destinations.
 *
 * @returns the compiled graph
 */
export const getEsqlSelfHealingGraph = ({
  esClient,
  connectorId,
  inference,
  logger,
  request,
}: {
  esClient: ElasticsearchClient
  connectorId: string,
  inference: InferenceServerStart,
  logger: Logger,
  request: KibanaRequest,
}
) => {
  // The same tool list is given to the agent and to the node that executes them.
  const tools: StructuredToolInterface[] = [
    getIndexNamesTool({ esClient }),
    getInspectIndexMappingTool({ esClient }),
  ]

  const toolsNode = new ToolNode(tools)
  const agentNode = getNlToEsqlAgent({ connectorId, inference, logger, request, tools })
  const esqlValidatorNode = getValidatorNode({ esClient })

  const builder = new StateGraph(EsqlSelfHealingAnnotation)
    .addNode(NL_TO_ESQL_AGENT_NODE, agentNode)
    .addNode(TOOLS_NODE, toolsNode)
    .addNode(ESQL_VALIDATOR_NODE, esqlValidatorNode, {
      ends: [END, NL_TO_ESQL_AGENT_NODE]
    })
    .addEdge(START, NL_TO_ESQL_AGENT_NODE)
    .addEdge(TOOLS_NODE, NL_TO_ESQL_AGENT_NODE)
    .addConditionalEdges(NL_TO_ESQL_AGENT_NODE, stepRouter, {
      [TOOLS_NODE]: TOOLS_NODE,
      [ESQL_VALIDATOR_NODE]: ESQL_VALIDATOR_NODE
    })

  return builder.compile()
}

View file

@ -0,0 +1,27 @@
import { ElasticsearchClient } from "@kbn/core/server";
import { tool } from "@langchain/core/tools";
// Name and model-facing description of the index-names tool.
const toolDetails = {
  name: "available_index_names",
  description: [
    "Get the available indices in the elastic search cluster.",
    "Use this when there is an unknown index error or you need to get the indeces that can be queried.",
    "Using the response select an appropriate index name.",
  ].join(" ")
}
/**
 * Creates a LangChain tool that lists index names via the cat indices API.
 * Entries without an index name are dropped and the remaining names are sorted.
 *
 * @param esClient Elasticsearch client used to list the indices
 * @returns the tool; invoking it resolves to a text listing, one index name per line
 */
export const getIndexNamesTool = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  const listIndexNames = async () => {
    const catEntries = await esClient.cat.indices({
      format: 'json'
    })
    const indexNames = catEntries
      .map((entry) => entry.index)
      .filter((name) => name!=undefined)
      .toSorted()
    return `These are the names of the available indeces. To query them, you must use the full index name verbatim.\n\n${indexNames.join('\n')}`
  }

  return tool(listIndexNames, {
    name: toolDetails.name,
    description: toolDetails.description
  })
}

View file

@ -0,0 +1,116 @@
import { ElasticsearchClient } from "@kbn/core/server"
import { tool } from "@langchain/core/tools"
import { z } from "zod"
// Name and model-facing description of the index-mapping inspection tool.
// The description is sent verbatim to the model, so every code fence must be a
// complete, escaped triple backtick on its own line.
const toolDetails = {
  name: "inspect_index_mapping",
  description: `Use this tool to inspect an index mapping. The tool will fetch the mappings of the provided indexName and then return the data at the given propertyKey. For example:
Example index mapping:
\`\`\`
{
  "mappings": {
    "properties": {
      "field1": {
        "type": "keyword"
      },
      "field2": {
        "properties": {
          "nested_field": {
            "type": "keyword"
          }
        }
      }
    }
  }
}
\`\`\`
Input:
\`\`\`
{
  "indexName": "my_index",
  "propertyKey": "mappings.properties"
}
\`\`\`
Output:
\`\`\`
{
  "field1": "Object",
  "field2": "Object"
}
\`\`\`
The tool can be called repeatedly to explode objects and arrays. For example:
Input:
\`\`\`
{
  "indexName": "my_index",
  "propertyKey": "mappings.properties.field1"
}
\`\`\`
Output:
\`\`\`
{
  "type": "keyword"
}
\`\`\``
}
/**
 * Creates a LangChain tool that fetches the mapping of an index and returns a
 * one-level summary of the object found at a dotted property path.
 *
 * @param esClient Elasticsearch client used to fetch the mapping
 * @returns the tool; its input is `indexName` plus an optional `propertyKey`
 *          (default "mappings.properties")
 */
export const getInspectIndexMappingTool = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  const inputSchema = z.object({
    indexName: z.string().describe(`The index name to get the properties of.`),
    propertyKey: z.string().optional().default("mappings.properties").describe(`The key to get the properties of.`)
  })

  return tool(async ({
    indexName,
    propertyKey
  }) => {
    const mappingsByIndex = await esClient.indices.getMapping({
      index: indexName
    })
    // Walk the dotted path inside this index's mapping, then summarise one level.
    const pathSegments = propertyKey.split(".")
    const summary = formatEntriesAtKey(getEntriesAtKey(mappingsByIndex[indexName], pathSegments))
    return `Object at ${propertyKey} \n${JSON.stringify(summary, null, 2)}`
  }, {
    name: toolDetails.name,
    description: toolDetails.description,
    schema: inputSchema
  })
}
/**
 * Follows `keys` through nested objects, starting at `mapping`.
 *
 * The `keys` array is only read, never modified, so callers can reuse it.
 *
 * @param mapping the object to start from
 * @param keys the property names to follow, outermost first
 * @returns the value found at the end of the path (the input itself when `keys`
 *          is empty), or undefined as soon as the path runs into a missing
 *          (undefined or null) value
 */
const getEntriesAtKey = (mapping: Record<string, any> | undefined, keys: readonly string[]): Record<string, any> | undefined => {
  let current: any = mapping
  for (const key of keys) {
    // Stop instead of throwing when the path leads through undefined or null.
    if (current === undefined || current === null) {
      return undefined
    }
    current = current[key]
  }
  return current
}
/**
 * Produces a one-level summary of an object: string values are kept as they
 * are, every other value is replaced by the literal "Object".
 *
 * @param mapping the object to summarise
 * @returns the summary, or an empty object when `mapping` is undefined
 */
const formatEntriesAtKey = (mapping: Record<string, any> | undefined): Record<string, string> => {
  const summary: Record<string, string> = {}
  if (mapping === undefined) {
    return summary
  }
  for (const [name, entry] of Object.entries(mapping)) {
    summary[name] = typeof entry === "string" ? entry : "Object"
  }
  return summary
}

View file

@ -1,5 +0,0 @@
// Placeholder suite: its single test has an empty body and asserts nothing.
describe("natural_lanagauge_to_esql_validator", ()=>{
it("test1", ()=>{
})
})

View file

@ -0,0 +1,44 @@
import { lastValueFrom } from "rxjs";
import { EsqlSelfHealingAnnotation } from "./state"
import { langchainMessageToInferenceMessage, langchainToolsToInferenceTools, nlToEsqlTaskEventToLangchainMessage } from "@kbn/elastic-assistant-common";
import { StructuredToolInterface } from "@langchain/core/tools";
import { KibanaRequest } from "@kbn/core/server";
import { InferenceServerStart, naturalLanguageToEsql } from "@kbn/inference-plugin/server";
import { Logger } from '@kbn/core/server';
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
import { Command } from "@langchain/langgraph";
/**
 * Creates the agent node of the graph.
 *
 * The returned node converts the state's LangChain messages to Inference
 * messages, runs `naturalLanguageToEsql` with the given tools, waits for the
 * last value the task emits, and returns a `Command` whose update contains that
 * event converted back to a LangChain message.
 */
export const getNlToEsqlAgent = ({
  connectorId,
  inference,
  logger,
  request,
  tools,
}:{
  connectorId: string,
  inference: InferenceServerStart,
  logger: Logger,
  request: KibanaRequest,
  tools: StructuredToolInterface[]
}) => {
  return async (state: typeof EsqlSelfHealingAnnotation.State) => {
    const inferenceMessages = state.messages.map((message) => langchainMessageToInferenceMessage(message))

    const events$ = naturalLanguageToEsql({
      client: inference.getClient({ request }),
      connectorId,
      functionCalling: 'auto',
      logger,
      tools: langchainToolsToInferenceTools(tools),
      messages: inferenceMessages
    })
    // Only the final emission of the task is used.
    const finalEvent = (await lastValueFrom(events$)) as ChatCompletionMessageEvent

    return new Command({
      update: {
        messages: [nlToEsqlTaskEventToLangchainMessage(finalEvent)]
      }
    })
  }
}

View file

@ -7,17 +7,11 @@
import { tool } from '@langchain/core/tools';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { z } from '@kbn/zod';
import { APP_UI_ID } from '../../../../common';
import { getEsqlFromContent, getPromptSuffixForOssModel } from './common';
import { NlToEsqlTaskEvent } from '@kbn/inference-plugin/server/tasks/nl_to_esql';
import { ToolOptions } from '@kbn/inference-common';
import { parse } from '@kbn/esql-ast';
import isEmpty from 'lodash/isEmpty';
import { ElasticsearchClient } from '@kbn/core/server';
import { NaturalLanguageToEsqlValidator } from './natual_language_to_esql_validator';
import { getPromptSuffixForOssModel } from './common';
import { HumanMessage } from '@langchain/core/messages';
import { getEsqlSelfHealingGraph } from './graph';
// Parameters accepted by the NL-to-ES|QL tool. Currently a plain alias of
// AssistantToolParams: no properties are selected or removed.
export type ESQLToolParams = AssistantToolParams;
@ -51,22 +45,21 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
const { connectorId, inference, logger, request, isOssModel, esClient } = params as ESQLToolParams;
if (inference == null || connectorId == null) return null;
const naturalLanguageToEsqlValidator = new NaturalLanguageToEsqlValidator({
inference,
const selfHealingGraph = getEsqlSelfHealingGraph({
esClient,
connectorId,
inference,
logger,
request,
esClient,
})
return tool(
async (input) => {
const answer = await naturalLanguageToEsqlValidator.generateEsqlFromNaturalLanguage(input.question);
console.log(`Received response from NL to ESQL tool: ${answer}`)
logger.debug(`Received response from NL to ESQL tool: ${answer}`);
return answer;
async ({ question }) => {
const humanMessage = new HumanMessage({ content: question })
const result = await selfHealingGraph.invoke({ messages: [humanMessage] })
const { messages } = result
const lastMessage = messages[messages.length - 1]
return lastMessage.content
},
{
name: toolDetails.name,

View file

@ -0,0 +1,3 @@
import { MessagesAnnotation } from "@langchain/langgraph";
// State definition of the ES|QL self-healing graph: LangGraph's MessagesAnnotation, re-exported under a local name.
export const EsqlSelfHealingAnnotation = MessagesAnnotation

View file

@ -0,0 +1,16 @@
import { EsqlSelfHealingAnnotation } from "./state";
import { ESQL_VALIDATOR_NODE, TOOLS_NODE } from "./constants";
/**
 * Chooses the next node from the last message in the state: the tools node when
 * that message carries at least one tool call, otherwise the validator node.
 *
 * @param state current graph state
 * @returns the name of the node to run next
 */
export const stepRouter = (state: typeof EsqlSelfHealingAnnotation.State): string => {
  const latestMessage = state.messages[state.messages.length - 1];

  const requestsTools =
    'tool_calls' in latestMessage &&
    Array.isArray(latestMessage.tool_calls) &&
    latestMessage.tool_calls.length > 0;

  return requestsTools ? TOOLS_NODE : ESQL_VALIDATOR_NODE
}

View file

@ -0,0 +1,118 @@
import { ElasticsearchClient } from "@kbn/core/server"
import { EsqlSelfHealingAnnotation } from "./state"
import { getEsqlFromContent } from "./common"
import { Command, END } from "@langchain/langgraph"
import { EditorError, parse } from "@kbn/esql-ast"
import { isEmpty } from "lodash"
import { BaseMessage, HumanMessage } from "@langchain/core/messages"
import { NL_TO_ESQL_AGENT_NODE } from "./constants"
/**
 * Outcome of validating one ES|QL query.
 */
type ValidateEsqlResult = {
// true when the query parsed without errors and executed without throwing
isValid: boolean
// the query text that was validated
query: string
// errors reported by the parser, when parsing failed
parsingErrors?: EditorError[]
// the value thrown while executing the query, when execution failed
executionError?: any
}
/**
 * Creates the validator node of the graph.
 *
 * The node extracts the ES|QL queries from the last message and validates each one:
 * - no queries found: go to END without changing the state;
 * - at least one invalid query: send the last message, annotated with a report
 *   for each query, back to the agent node;
 * - all queries valid: go to END with a message confirming the validation.
 */
export const getValidatorNode = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  return async (state: typeof EsqlSelfHealingAnnotation.State) => {
    const latestMessage = state.messages[state.messages.length - 1];
    const latestContent = latestMessage.content as string

    const queries = getEsqlFromContent(latestContent)
    if (queries.length === 0) {
      return new Command({
        goto: END
      })
    }

    const validations = await Promise.all(
      queries.map((esql) => validateEsql(esClient, esql))
    )

    const allQueriesValid = validations.every((validation) => validation.isValid)
    if (allQueriesValid) {
      return new Command({
        goto: END,
        update: {
          messages: `${latestMessage.content} \nAll of the queries have been validated and do not need to modified futher.`
        }
      })
    }

    return new Command({
      update: {
        messages: [lastMessageWithErrorReport(latestContent, validations)]
      },
      goto: NL_TO_ESQL_AGENT_NODE
    })
  }
}
/**
 * Returns a human message containing `message` with a validation report
 * inserted directly after the code fence that follows each query.
 *
 * Results are processed last-to-first. The input array is copied before being
 * reversed, so the caller's array keeps its order.
 *
 * When a query, or a code fence after it, cannot be found in the message, the
 * query and its report are appended to the end of the message instead of being
 * spliced in at an unrelated position.
 *
 * @param message content of the message that contained the queries
 * @param validateEsqlResults validation results, one per query
 * @returns a `HumanMessage` with the annotated content
 */
const lastMessageWithErrorReport = (message: string, validateEsqlResults: ValidateEsqlResult[]): BaseMessage => {
  const fence = '```'
  let messageWithErrorReport = message

  for (const validateEsqlResult of [...validateEsqlResults].reverse()) {
    const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult)

    const queryIndex = messageWithErrorReport.indexOf(validateEsqlResult.query)
    const fenceIndex = queryIndex === -1 ? -1 : messageWithErrorReport.indexOf(fence, queryIndex)

    if (fenceIndex === -1) {
      // No reliable anchor: keep the report by appending it together with its query.
      messageWithErrorReport = `${messageWithErrorReport}\n${validateEsqlResult.query}\n${errorMessage}\n`
      continue
    }

    const insertAt = fenceIndex + fence.length
    messageWithErrorReport = `${messageWithErrorReport.slice(0, insertAt)}\n${errorMessage}\n${messageWithErrorReport.slice(insertAt)}`
  }

  return new HumanMessage({
    content: messageWithErrorReport
  })
}
/**
 * Renders a validation result as text: "Query is valid" for a valid query,
 * otherwise a header followed by the parsing error messages and/or the
 * execution error message, each section ending with a newline.
 */
const formatValidateEsqlResultToHumanReadable = (validateEsqlResult: ValidateEsqlResult) => {
  const { isValid, parsingErrors, executionError } = validateEsqlResult
  if (isValid) {
    return "Query is valid"
  }

  const sections: string[] = ["This query has errors:\n"]
  if (parsingErrors) {
    const parsingLines = parsingErrors.map(({ message }) => message)
    sections.push(`${parsingLines.join('\n')}\n`)
  }
  if (executionError) {
    sections.push(`${extractErrorMessage(executionError)}\n`)
  }
  return sections.join('')
}
/**
 * Validates one ES|QL query in two stages: it is parsed first, and only when
 * parsing reports no errors is it executed against Elasticsearch.
 *
 * @param esClient Elasticsearch client used to execute the query
 * @param query the ES|QL query text
 * @returns the validation result; failures are reported in the result, not thrown
 */
const validateEsql = async (esClient: ElasticsearchClient, query: string): Promise<ValidateEsqlResult> => {
  const parseOutcome = parse(query)
  const hasParsingErrors = !isEmpty(parseOutcome.errors)
  if (hasParsingErrors) {
    return { isValid: false, query, parsingErrors: parseOutcome.errors }
  }

  try {
    // The response is discarded: only whether the call throws matters here.
    await esClient.esql.query({ query, format: 'json' })
    return { isValid: true, query }
  } catch (executionError) {
    return { isValid: false, query, executionError }
  }
}
/**
 * Extracts a printable message from a thrown value.
 *
 * @param error anything caught in a `catch` clause
 * @returns the value itself when it is a non-empty string, its `message` when
 *          that is a non-empty string, otherwise "Unknown error". Never throws,
 *          including for `null` and `undefined`.
 */
const extractErrorMessage = (error: unknown): string => {
  if (typeof error === 'string' && error.length > 0) {
    return error
  }
  if (typeof error === 'object' && error !== null && 'message' in error) {
    const { message } = error as { message?: unknown }
    if (typeof message === 'string' && message.length > 0) {
      return message
    }
  }
  return `Unknown error`
}