esql self healing graph

Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 17:28:26 -04:00)
Parent: ad3dce1179
Commit: c76a4469ae
16 changed files with 623 additions and 28 deletions
@@ -0,0 +1,48 @@
import { BaseMessage as LangChainBaseMessage, AIMessage as LangChainAIMessage, ToolMessage as LangChainToolMessage } from '@langchain/core/messages';
import { AssistantMessage as InferenceAssistantMessage, Message as InferenceMessage, MessageRole, ToolCall, ToolMessage, UserMessage } from '@kbn/inference-common';
import { isEmpty } from 'lodash';

export const langchainMessageToInferenceMessage = (langChainMessage: LangChainBaseMessage): InferenceMessage => {
  switch (langChainMessage.getType()) {
    case "system":
    case "generic":
    case "developer":
    case "human": {
      return {
        role: MessageRole.User,
        content: langChainMessage.content
      } as UserMessage
    }
    case "ai": {
      if (langChainMessage instanceof LangChainAIMessage) {
        const toolCalls: ToolCall[] | undefined = langChainMessage.tool_calls?.map((toolCall): ToolCall => ({
          toolCallId: toolCall.id as string,
          function: {
            ...(!isEmpty(toolCall.args) ? { arguments: toolCall.args } : {}),
            name: toolCall.name
          }
        }))
        return {
          role: MessageRole.Assistant,
          content: langChainMessage.content as string,
          ...(!isEmpty(toolCalls) ? { toolCalls } : {})
        } as InferenceAssistantMessage
      }
      throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
    }
    case "tool": {
      if (langChainMessage instanceof LangChainToolMessage) {
        return {
          name: langChainMessage.name,
          toolCallId: langChainMessage.tool_call_id,
          role: MessageRole.Tool,
          response: langChainMessage.content,
        } as ToolMessage
      }
      throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
    }
    default: {
      throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
    }
  }
}
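A minimal usage sketch of the converter above, assuming the package-level exports added later in this commit; the message content and tool-call id are made up for illustration:

import { AIMessage } from '@langchain/core/messages';
import { langchainMessageToInferenceMessage } from '@kbn/elastic-assistant-common';

// Hypothetical LangChain assistant message that carries a single tool call.
const aiMessage = new AIMessage({
  content: '',
  tool_calls: [{ id: 'call_1', name: 'available_index_names', args: {}, type: 'tool_call' }],
});

// Yields { role: MessageRole.Assistant, content: '', toolCalls: [{ toolCallId: 'call_1', function: { name: 'available_index_names' } }] }
// (empty args are dropped by the isEmpty check above).
const inferenceMessage = langchainMessageToInferenceMessage(aiMessage);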
@@ -0,0 +1,24 @@
import { ToolDefinition as InferenceToolDefinition, ToolSchema as InferenceToolSchema } from '@kbn/inference-common';
import { StructuredToolInterface as LangChainStructuredToolInterface } from '@langchain/core/tools';
import { pick } from 'lodash';
import { ZodSchema } from 'zod';
import zodToJsonSchema from 'zod-to-json-schema';

export const langchainToolToInferenceTool = (langchainTool: LangChainStructuredToolInterface): InferenceToolDefinition => {
  const schema = langchainTool.schema ? zodSchemaToInference(langchainTool.schema) : undefined;
  return {
    description: langchainTool.description,
    ...(schema ? { schema } : {}),
  }
}

export const langchainToolsToInferenceTools = (langchainTools: LangChainStructuredToolInterface[]): Record<string, InferenceToolDefinition> => {
  return langchainTools.reduce((acc, tool) => {
    acc[tool.name] = langchainToolToInferenceTool(tool);
    return acc;
  }, {} as Record<string, InferenceToolDefinition>);
}

function zodSchemaToInference(schema: ZodSchema): InferenceToolSchema {
  return pick(zodToJsonSchema(schema), ['type', 'properties', 'required']) as InferenceToolSchema;
}
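A sketch of how a zod-backed LangChain tool would be converted, again assuming the package-level exports added later in this commit; the echo tool is hypothetical:

import { tool } from '@langchain/core/tools';
import { z } from 'zod';
import { langchainToolsToInferenceTools } from '@kbn/elastic-assistant-common';

// Hypothetical structured tool with a zod schema.
const echoTool = tool(async ({ text }) => text, {
  name: 'echo',
  description: 'Echoes the provided text back.',
  schema: z.object({ text: z.string() }),
});

// Produces { echo: { description: 'Echoes the provided text back.',
//   schema: { type: 'object', properties: { text: { type: 'string' } }, required: ['text'] } } }
const inferenceTools = langchainToolsToInferenceTools([echoTool]);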
@@ -0,0 +1,26 @@
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
import { AIMessage, BaseMessage as LangChainBaseMessage } from '@langchain/core/messages';
import { isEmpty } from "lodash";

export const nlToEsqlTaskEventToLangchainMessage = (taskEvent: ChatCompletionMessageEvent): LangChainBaseMessage => {
  switch (taskEvent.type) {
    case "chatCompletionMessage": {
      const toolCalls = taskEvent.toolCalls.map((toolCall) => {
        return {
          name: toolCall.function.name,
          args: toolCall.function.arguments,
          id: toolCall.toolCallId,
          type: "tool_call"
        } as const
      })
      return new AIMessage({
        content: taskEvent.content,
        ...(!isEmpty(toolCalls) ? { tool_calls: toolCalls } : {})
      })
    }
    default: {
      throw new Error(`Unable to convert nlToEsqlTaskEvent of type ${taskEvent.type} to LangChain message`)
    }
  }
}
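For reference, a sketch of the event shape this converter expects (a chatCompletionMessage event emitted by the naturalLanguageToEsql task); the content is illustrative and the cast sidesteps the event-type enum for brevity:

import type { ChatCompletionMessageEvent } from '@kbn/inference-common';
import { nlToEsqlTaskEventToLangchainMessage } from '@kbn/elastic-assistant-common';

// Illustrative completion event with no tool calls.
const taskEvent = {
  type: 'chatCompletionMessage',
  content: 'FROM my_index | LIMIT 10',
  toolCalls: [],
} as unknown as ChatCompletionMessageEvent;

// Produces an AIMessage whose content is the generated answer, with no tool_calls attached.
const aiMessage = nlToEsqlTaskEventToLangchainMessage(taskEvent);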
@@ -42,6 +42,14 @@ export type {

export { transformRawData } from './impl/data_anonymization/transform_raw_data';
export { parseBedrockBuffer, handleBedrockChunk } from './impl/utils/bedrock';
export { langchainMessageToInferenceMessage } from './impl/utils/langchain_message_to_inference_message';
export {
  langchainToolToInferenceTool,
  langchainToolsToInferenceTools
} from './impl/utils/langchain_tool_to_inference_tool';
export {
  nlToEsqlTaskEventToLangchainMessage
} from './impl/utils/nl_to_esql_task_event_to_langchain_message';
export * from './constants';

/** currently the same shape as "fields" property in the ES response */
@@ -0,0 +1,61 @@
import { AssistantMessage, MessageRole, UserMessage, ToolMessage } from "@kbn/inference-common"
import { ensureMultiTurn } from "./ensure_multi_turn"

const assistantMessage: AssistantMessage = {
  role: MessageRole.Assistant,
  content: "Hello from assistant"
}

const userMessage: UserMessage = {
  role: MessageRole.User,
  content: "Hello from user"
}

const toolMessage: ToolMessage = {
  role: MessageRole.Tool,
  response: "Hello from tool",
  toolCallId: "123",
  name: "toolName"
}

const intermediaryUserMessage: UserMessage = {
  role: MessageRole.User,
  content: "-"
}

const intermediaryAssistantMessage: AssistantMessage = {
  role: MessageRole.Assistant,
  content: "-"
}

describe("ensureMultiTurn", () => {
  it("returns correct value for message sequence", () => {
    const messages = [
      assistantMessage,
      assistantMessage,
      userMessage,
      userMessage,
      toolMessage,
      toolMessage,
      userMessage,
      assistantMessage,
      toolMessage
    ]

    const result = ensureMultiTurn(messages)

    expect(result).toEqual([
      assistantMessage,
      intermediaryUserMessage,
      assistantMessage,
      userMessage,
      intermediaryAssistantMessage,
      userMessage,
      toolMessage,
      toolMessage,
      userMessage,
      assistantMessage,
      toolMessage
    ])
  })
})
@@ -7,20 +7,65 @@

import { Message, MessageRole } from '@kbn/inference-common';

function isUserMessage(message: Message): boolean {
  return message.role !== MessageRole.Assistant;
type MessageRoleSequenceResult = {
  roleSequenceValid: true;
} | {
  roleSequenceValid: false;
  intermediaryRole: MessageRole.User | MessageRole.Assistant;
};

/**
 * Two consecutive messages with USER role or two consecutive messages with ASSISTANT role are not allowed.
 * Consecutive messages with TOOL role are allowed.
 */
function checkMessageRoleSequenceValid(prevMessage: Message | undefined, message: Message): MessageRoleSequenceResult {
  if (!prevMessage) {
    return {
      roleSequenceValid: true
    };
  }

  if (prevMessage.role === MessageRole.User && message.role === MessageRole.User) {
    return {
      roleSequenceValid: false,
      intermediaryRole: MessageRole.Assistant
    };
  }
  if (prevMessage.role === MessageRole.Assistant && message.role === MessageRole.Assistant) {
    return {
      roleSequenceValid: false,
      intermediaryRole: MessageRole.User
    };
  }
  if (prevMessage.role === MessageRole.Tool && message.role === MessageRole.Tool) {
    return {
      roleSequenceValid: true
    };
  }
  if (prevMessage.role === message.role) {
    return {
      roleSequenceValid: false,
      intermediaryRole: MessageRole.Assistant
    }
  }

  return {
    roleSequenceValid: true
  }
}

export function ensureMultiTurn(messages: Message[]): Message[] {
  const next: Message[] = [];

  messages.forEach((message) => {
    const prevMessage = next[next.length - 1];

    if (prevMessage && isUserMessage(prevMessage) === isUserMessage(message)) {
    const result = checkMessageRoleSequenceValid(prevMessage, message);
    if (!result.roleSequenceValid) {
      next.push({
        content: '-',
        role: isUserMessage(message) ? MessageRole.Assistant : MessageRole.User,
        role: result.intermediaryRole,
      });
    }
@@ -0,0 +1,3 @@
export const NL_TO_ESQL_AGENT_NODE = "nl_to_esql_agent";
export const TOOLS_NODE = "tools";
export const ESQL_VALIDATOR_NODE = "esql_validator";
@@ -0,0 +1,68 @@
import { StructuredToolInterface } from "@langchain/core/tools";
import { getIndexNamesTool } from "./index_names_tool";
import { ElasticsearchClient, KibanaRequest } from "@kbn/core/server";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { END, START, StateGraph } from "@langchain/langgraph";
import { EsqlSelfHealingAnnotation } from "./state";
import { ESQL_VALIDATOR_NODE, NL_TO_ESQL_AGENT_NODE, TOOLS_NODE } from "./constants";
import { stepRouter } from "./step_router";
import { getNlToEsqlAgent } from "./nl_to_esql_agent";
import { InferenceServerStart } from "@kbn/inference-plugin/server";
import { Logger } from '@kbn/core/server';
import { getValidatorNode } from "./validator";
import { getInspectIndexMappingTool } from "./inspect_index_mapping_tool";

export const getEsqlSelfHealingGraph = ({
  esClient,
  connectorId,
  inference,
  logger,
  request,
}: {
  esClient: ElasticsearchClient,
  connectorId: string,
  inference: InferenceServerStart,
  logger: Logger,
  request: KibanaRequest,
}) => {
  const availableIndexNamesTool = getIndexNamesTool({
    esClient
  })
  const inspectIndexMappingTool = getInspectIndexMappingTool({
    esClient
  })

  const tools: StructuredToolInterface[] = [
    availableIndexNamesTool,
    inspectIndexMappingTool
  ]

  const toolNode = new ToolNode(tools)
  const nlToEsqlAgentNode = getNlToEsqlAgent({
    connectorId,
    inference,
    logger,
    request,
    tools,
  })
  const validatorNode = getValidatorNode({
    esClient
  })

  const graph = new StateGraph(EsqlSelfHealingAnnotation)
    .addNode(NL_TO_ESQL_AGENT_NODE, nlToEsqlAgentNode)
    .addNode(TOOLS_NODE, toolNode)
    .addNode(ESQL_VALIDATOR_NODE, validatorNode, {
      ends: [END, NL_TO_ESQL_AGENT_NODE]
    })
    .addEdge(START, NL_TO_ESQL_AGENT_NODE)
    .addEdge(TOOLS_NODE, NL_TO_ESQL_AGENT_NODE)
    .addConditionalEdges(NL_TO_ESQL_AGENT_NODE, stepRouter, {
      [TOOLS_NODE]: TOOLS_NODE,
      [ESQL_VALIDATOR_NODE]: ESQL_VALIDATOR_NODE
    }).compile()

  return graph
}
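A minimal sketch of driving the compiled graph, mirroring how the NL-to-ES|QL tool later in this commit invokes it; the dependency object is assumed to come from the plugin's request handler context:

import { HumanMessage } from '@langchain/core/messages';
import { getEsqlSelfHealingGraph } from './graph';

// Sketch: ask a natural-language question and return the content of the last (validated) message.
const answerQuestion = async (
  deps: Parameters<typeof getEsqlSelfHealingGraph>[0],
  question: string
) => {
  const graph = getEsqlSelfHealingGraph(deps);
  const { messages } = await graph.invoke({ messages: [new HumanMessage({ content: question })] });
  return messages[messages.length - 1].content;
};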
@@ -0,0 +1,27 @@
import { ElasticsearchClient } from "@kbn/core/server";
import { tool } from "@langchain/core/tools";

const toolDetails = {
  name: "available_index_names",
  description: "Get the available indices in the Elasticsearch cluster. Use this when there is an unknown index error or you need to get the indices that can be queried. Using the response, select an appropriate index name."
}

export const getIndexNamesTool = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  return tool(async () => {
    const indexNames = await esClient.cat.indices({
      format: 'json'
    }).then((response) => response
      .map((index) => index.index)
      .filter((index) => index != undefined)
      .toSorted()
    )
    return `These are the names of the available indices. To query them, you must use the full index name verbatim.\n\n${indexNames.join('\n')}`
  }, {
    name: toolDetails.name,
    description: toolDetails.description
  })
}
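A usage sketch, assuming an ElasticsearchClient from the request handler context; in the graph this tool is invoked for the model by the ToolNode, but it can also be called directly:

// Hypothetical direct invocation; the tool ignores its input.
const indexNamesTool = getIndexNamesTool({ esClient });
const indexList = await indexNamesTool.invoke('');
// indexList is a sentence followed by one index name per line, sorted alphabetically.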
@@ -0,0 +1,116 @@
import { ElasticsearchClient } from "@kbn/core/server"
import { tool } from "@langchain/core/tools"
import { z } from "zod"

const toolDetails = {
  name: "inspect_index_mapping",
  description: `Use this tool to inspect an index mapping. The tool will fetch the mappings of the provided indexName and then return the data at the given propertyKey. For example:

Example index mapping:
\`\`\`
{
  "mappings": {
    "properties": {
      "field1": {
        "type": "keyword"
      },
      "field2": {
        "properties": {
          "nested_field": {
            "type": "keyword"
          }
        }
      }
    }
  }
}
\`\`\`

Input:
\`\`\`
{
  "indexName": "my_index",
  "propertyKey": "mappings.properties"
}
\`\`\`

Output:
\`\`\`
{
  "field1": "Object",
  "field2": "Object"
}
\`\`\`

The tool can be called repeatedly to explode objects and arrays. For example:

Input:
\`\`\`
{
  "indexName": "my_index",
  "propertyKey": "mappings.properties.field1"
}
\`\`\`

Output:
\`\`\`
{
  "type": "keyword"
}
\`\`\``
}

export const getInspectIndexMappingTool = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  return tool(async ({
    indexName,
    propertyKey
  }) => {
    const indexMapping = await esClient.indices.getMapping({
      index: indexName
    })

    const entriesAtKey = getEntriesAtKey(indexMapping[indexName], propertyKey.split("."))
    const result = formatEntriesAtKey(entriesAtKey)

    return `Object at ${propertyKey} \n${JSON.stringify(result, null, 2)}`
  }, {
    name: toolDetails.name,
    description: toolDetails.description,
    schema: z.object({
      indexName: z.string().describe(`The index name to get the properties of.`),
      propertyKey: z.string().optional().default("mappings.properties").describe(`The key to get the properties of.`)
    })
  })
}

const getEntriesAtKey = (mapping: Record<string, any> | undefined, keys: string[]): Record<string, any> | undefined => {
  if (mapping === undefined) {
    return
  }
  if (keys.length === 0) {
    return mapping
  }

  const key = keys.shift()
  if (key === undefined) {
    return mapping
  }

  return getEntriesAtKey(mapping[key], keys)
}

const formatEntriesAtKey = (mapping: Record<string, any> | undefined): Record<string, string> => {
  if (mapping === undefined) {
    return {}
  }
  return Object.entries(mapping).reduce((acc, [key, value]) => {
    acc[key] = typeof value === "string" ? value : "Object"
    return acc
  }, {} as Record<string, string>)
}
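A short sketch of calling the mapping-inspection tool directly, assuming an ElasticsearchClient and an existing index named "my_index" (the same example index used in the tool description):

// Hypothetical direct invocation; in the graph this happens through the ToolNode.
const inspectTool = getInspectIndexMappingTool({ esClient });
const topLevelFields = await inspectTool.invoke({
  indexName: 'my_index',
  propertyKey: 'mappings.properties',
});
// Returns a string such as: Object at mappings.properties followed by { "field1": "Object", ... }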
@@ -1,5 +0,0 @@
describe("natural_lanagauge_to_esql_validator", () => {
  it("test1", () => {

  })
})
@@ -0,0 +1,44 @@
import { lastValueFrom } from "rxjs";
import { EsqlSelfHealingAnnotation } from "./state"
import { langchainMessageToInferenceMessage, langchainToolsToInferenceTools, nlToEsqlTaskEventToLangchainMessage } from "@kbn/elastic-assistant-common";
import { StructuredToolInterface } from "@langchain/core/tools";
import { KibanaRequest } from "@kbn/core/server";
import { InferenceServerStart, naturalLanguageToEsql } from "@kbn/inference-plugin/server";
import { Logger } from '@kbn/core/server';
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
import { Command } from "@langchain/langgraph";

export const getNlToEsqlAgent = ({
  connectorId,
  inference,
  logger,
  request,
  tools,
}: {
  connectorId: string,
  inference: InferenceServerStart,
  logger: Logger,
  request: KibanaRequest,
  tools: StructuredToolInterface[]
}) => {
  return async (state: typeof EsqlSelfHealingAnnotation.State) => {
    const { messages: stateMessages } = state;

    const inferenceMessages = stateMessages.map(langchainMessageToInferenceMessage)

    const result = await lastValueFrom(naturalLanguageToEsql({
      client: inference.getClient({ request }),
      connectorId,
      functionCalling: 'auto',
      logger,
      tools: langchainToolsToInferenceTools(tools),
      messages: inferenceMessages
    })) as ChatCompletionMessageEvent

    return new Command({
      update: {
        messages: [nlToEsqlTaskEventToLangchainMessage(result)]
      }
    })
  }
}
@@ -7,17 +7,11 @@

import { tool } from '@langchain/core/tools';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { z } from '@kbn/zod';
import { APP_UI_ID } from '../../../../common';
import { getEsqlFromContent, getPromptSuffixForOssModel } from './common';
import { NlToEsqlTaskEvent } from '@kbn/inference-plugin/server/tasks/nl_to_esql';
import { ToolOptions } from '@kbn/inference-common';
import { parse } from '@kbn/esql-ast';
import isEmpty from 'lodash/isEmpty';
import { ElasticsearchClient } from '@kbn/core/server';
import { NaturalLanguageToEsqlValidator } from './natual_language_to_esql_validator';
import { getPromptSuffixForOssModel } from './common';
import { HumanMessage } from '@langchain/core/messages';
import { getEsqlSelfHealingGraph } from './graph';

// select only some properties of AssistantToolParams
export type ESQLToolParams = AssistantToolParams;
@@ -51,22 +45,21 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
    const { connectorId, inference, logger, request, isOssModel, esClient } = params as ESQLToolParams;
    if (inference == null || connectorId == null) return null;

    const naturalLanguageToEsqlValidator = new NaturalLanguageToEsqlValidator({
      inference,
    const selfHealingGraph = getEsqlSelfHealingGraph({
      esClient,
      connectorId,
      inference,
      logger,
      request,
      esClient,
    })

    return tool(
      async (input) => {
        const answer = await naturalLanguageToEsqlValidator.generateEsqlFromNaturalLanguage(input.question);

        console.log(`Received response from NL to ESQL tool: ${answer}`)
        logger.debug(`Received response from NL to ESQL tool: ${answer}`);
        return answer;
      async ({ question }) => {
        const humanMessage = new HumanMessage({ content: question })
        const result = await selfHealingGraph.invoke({ messages: [humanMessage] })
        const { messages } = result
        const lastMessage = messages[messages.length - 1]
        return lastMessage.content
      },
      {
        name: toolDetails.name,
@@ -0,0 +1,3 @@
import { MessagesAnnotation } from "@langchain/langgraph";

export const EsqlSelfHealingAnnotation = MessagesAnnotation
@@ -0,0 +1,16 @@
import { EsqlSelfHealingAnnotation } from "./state";
import { ESQL_VALIDATOR_NODE, TOOLS_NODE } from "./constants";

export const stepRouter = (state: typeof EsqlSelfHealingAnnotation.State): string => {
  const { messages } = state;
  const lastMessage = messages[messages.length - 1];

  if (
    'tool_calls' in lastMessage &&
    Array.isArray(lastMessage.tool_calls) &&
    lastMessage.tool_calls?.length
  ) {
    return TOOLS_NODE;
  }

  return ESQL_VALIDATOR_NODE
}
@@ -0,0 +1,118 @@
import { ElasticsearchClient } from "@kbn/core/server"
import { EsqlSelfHealingAnnotation } from "./state"
import { getEsqlFromContent } from "./common"
import { Command, END } from "@langchain/langgraph"
import { EditorError, parse } from "@kbn/esql-ast"
import { isEmpty } from "lodash"
import { BaseMessage, HumanMessage } from "@langchain/core/messages"
import { NL_TO_ESQL_AGENT_NODE } from "./constants"

type ValidateEsqlResult = {
  isValid: boolean
  query: string
  parsingErrors?: EditorError[]
  executionError?: any
}

export const getValidatorNode = ({
  esClient
}: {
  esClient: ElasticsearchClient
}) => {
  return async (state: typeof EsqlSelfHealingAnnotation.State) => {
    const { messages } = state;
    const lastMessage = messages[messages.length - 1];

    const generatedQueries = getEsqlFromContent(lastMessage.content as string)

    if (!generatedQueries.length) {
      return new Command({
        goto: END
      })
    }

    const validateEsqlResults = await Promise.all(
      generatedQueries.map(query => validateEsql(esClient, query))
    )

    const containsInvalidQueries = validateEsqlResults.some(result => !result.isValid)

    if (containsInvalidQueries) {
      return new Command({
        update: {
          messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)]
        },
        goto: NL_TO_ESQL_AGENT_NODE
      })
    }

    return new Command({
      goto: END,
      update: {
        messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to be modified further.`
      }
    })
  }
}

const lastMessageWithErrorReport = (message: string, validateEsqlResults: ValidateEsqlResult[]): BaseMessage => {
  let messageWithErrorReport = message
  validateEsqlResults.reverse().forEach((validateEsqlResult) => {
    const index = messageWithErrorReport.indexOf('```', messageWithErrorReport.indexOf(validateEsqlResult.query))
    const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult)
    messageWithErrorReport = `${messageWithErrorReport.slice(0, index + 3)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`
  })

  return new HumanMessage({
    content: messageWithErrorReport
  })
}

const formatValidateEsqlResultToHumanReadable = (validateEsqlResult: ValidateEsqlResult) => {
  if (validateEsqlResult.isValid) {
    return "Query is valid"
  }
  let errorMessage = "This query has errors:\n"
  if (validateEsqlResult.parsingErrors) {
    errorMessage += `${validateEsqlResult.parsingErrors.map((error) => error.message).join('\n')}\n`
  }
  if (validateEsqlResult.executionError) {
    errorMessage += `${extractErrorMessage(validateEsqlResult.executionError)}\n`
  }
  return errorMessage
}

const validateEsql = async (esClient: ElasticsearchClient, query: string): Promise<ValidateEsqlResult> => {
  const { errors: parsingErrors } = parse(query)
  if (!isEmpty(parsingErrors)) {
    return {
      isValid: false,
      query,
      parsingErrors
    }
  }

  try {
    await esClient.esql.query(
      {
        query,
        format: 'json',
      },
    )
  } catch (executionError) {
    return {
      isValid: false,
      query,
      executionError
    }
  }

  return {
    isValid: true,
    query
  }
}

const extractErrorMessage = (error: any): string => {
  return error.message || `Unknown error`
}
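To make the validator's error-report splicing concrete, a simplified illustration of what lastMessageWithErrorReport does for a single invalid query; the message text and error are made up:

// The assistant message contains a fenced query block; the report is inserted right after its closing fence.
const original = 'Here is the query:\n```esql\nFROM missing_index | LIMIT 10\n```';
const errorReport = 'This query has errors:\nUnknown index [missing_index]';

const queryStart = original.indexOf('FROM missing_index');
const fenceEnd = original.indexOf('```', queryStart) + 3;
const annotated = `${original.slice(0, fenceEnd)}\n${errorReport}\n${original.slice(fenceEnd)}`;
// The annotated text is wrapped in a HumanMessage and routed back to the NL-to-ES|QL agent node for another attempt.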