mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
[CI] Auto-commit changed files from 'node scripts/eslint --no-cache --fix'
This commit is contained in:
parent
ed5e4fed7f
commit
355b859fcd
18 changed files with 757 additions and 547 deletions
|
@ -1,48 +1,76 @@
|
|||
import { BaseMessage as LangChainBaseMessage, AIMessage as LangChainAIMessage, ToolMessage as LangChainToolMessage } from '@langchain/core/messages';
|
||||
import { AssistantMessage as InferenceAssistantMessage, Message as InferenceMessage, MessageRole, ToolCall, ToolMessage, UserMessage } from '@kbn/inference-common';
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
BaseMessage as LangChainBaseMessage,
|
||||
AIMessage as LangChainAIMessage,
|
||||
ToolMessage as LangChainToolMessage,
|
||||
} from '@langchain/core/messages';
|
||||
import {
|
||||
AssistantMessage as InferenceAssistantMessage,
|
||||
Message as InferenceMessage,
|
||||
MessageRole,
|
||||
ToolCall,
|
||||
ToolMessage,
|
||||
UserMessage,
|
||||
} from '@kbn/inference-common';
|
||||
import { isEmpty } from 'lodash';
|
||||
|
||||
export const langchainMessageToInferenceMessage = (langChainMessage: LangChainBaseMessage): InferenceMessage => {
|
||||
switch (langChainMessage.getType()) {
|
||||
case "system":
|
||||
case "generic":
|
||||
case "developer":
|
||||
case "human": {
|
||||
return {
|
||||
role: MessageRole.User,
|
||||
content: langChainMessage.content
|
||||
} as UserMessage
|
||||
}
|
||||
case "ai": {
|
||||
if (langChainMessage instanceof LangChainAIMessage) {
|
||||
const toolCalls: ToolCall[] | undefined = langChainMessage.tool_calls?.map((toolCall): ToolCall => ({
|
||||
toolCallId: toolCall.id as string,
|
||||
function: {
|
||||
...(!isEmpty(toolCall.args) ? { arguments: toolCall.args } : {}),
|
||||
name: toolCall.name
|
||||
}
|
||||
}))
|
||||
return {
|
||||
role: MessageRole.Assistant,
|
||||
content: langChainMessage.content as string,
|
||||
...(!isEmpty(toolCalls) ? { toolCalls } : {})
|
||||
} as InferenceAssistantMessage
|
||||
}
|
||||
throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
|
||||
}
|
||||
case "tool": {
|
||||
if (langChainMessage instanceof LangChainToolMessage) {
|
||||
return {
|
||||
name: langChainMessage.name,
|
||||
toolCallId: langChainMessage.tool_call_id,
|
||||
role: MessageRole.Tool,
|
||||
response: langChainMessage.content,
|
||||
} as ToolMessage
|
||||
}
|
||||
throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
|
||||
}
|
||||
default: {
|
||||
throw new Error(`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`)
|
||||
}
|
||||
export const langchainMessageToInferenceMessage = (
|
||||
langChainMessage: LangChainBaseMessage
|
||||
): InferenceMessage => {
|
||||
switch (langChainMessage.getType()) {
|
||||
case 'system':
|
||||
case 'generic':
|
||||
case 'developer':
|
||||
case 'human': {
|
||||
return {
|
||||
role: MessageRole.User,
|
||||
content: langChainMessage.content,
|
||||
} as UserMessage;
|
||||
}
|
||||
}
|
||||
case 'ai': {
|
||||
if (langChainMessage instanceof LangChainAIMessage) {
|
||||
const toolCalls: ToolCall[] | undefined = langChainMessage.tool_calls?.map(
|
||||
(toolCall): ToolCall => ({
|
||||
toolCallId: toolCall.id as string,
|
||||
function: {
|
||||
...(!isEmpty(toolCall.args) ? { arguments: toolCall.args } : {}),
|
||||
name: toolCall.name,
|
||||
},
|
||||
})
|
||||
);
|
||||
return {
|
||||
role: MessageRole.Assistant,
|
||||
content: langChainMessage.content as string,
|
||||
...(!isEmpty(toolCalls) ? { toolCalls } : {}),
|
||||
} as InferenceAssistantMessage;
|
||||
}
|
||||
throw new Error(
|
||||
`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`
|
||||
);
|
||||
}
|
||||
case 'tool': {
|
||||
if (langChainMessage instanceof LangChainToolMessage) {
|
||||
return {
|
||||
name: langChainMessage.name,
|
||||
toolCallId: langChainMessage.tool_call_id,
|
||||
role: MessageRole.Tool,
|
||||
response: langChainMessage.content,
|
||||
} as ToolMessage;
|
||||
}
|
||||
throw new Error(
|
||||
`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`
|
||||
);
|
||||
}
|
||||
default: {
|
||||
throw new Error(
|
||||
`Unable to convert LangChain message of type ${langChainMessage.getType()} to Inference message`
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,24 +1,38 @@
|
|||
import { ToolDefinition as InferenceToolDefinition, ToolSchema as InferenceToolSchema } from '@kbn/inference-common';
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import {
|
||||
ToolDefinition as InferenceToolDefinition,
|
||||
ToolSchema as InferenceToolSchema,
|
||||
} from '@kbn/inference-common';
|
||||
import { StructuredToolInterface as LangChainStructuredToolInterface } from '@langchain/core/tools';
|
||||
import { pick } from 'lodash';
|
||||
import { ZodSchema } from 'zod';
|
||||
import { ZodSchema } from '@kbn/zod';
|
||||
import zodToJsonSchema from 'zod-to-json-schema';
|
||||
|
||||
export const langchainToolToInferenceTool = (langchainTool: LangChainStructuredToolInterface): InferenceToolDefinition => {
|
||||
const schema = langchainTool.schema ? zodSchemaToInference(langchainTool.schema) : undefined;
|
||||
return {
|
||||
description: langchainTool.description,
|
||||
...(schema ? { schema } : {}),
|
||||
}
|
||||
}
|
||||
export const langchainToolToInferenceTool = (
|
||||
langchainTool: LangChainStructuredToolInterface
|
||||
): InferenceToolDefinition => {
|
||||
const schema = langchainTool.schema ? zodSchemaToInference(langchainTool.schema) : undefined;
|
||||
return {
|
||||
description: langchainTool.description,
|
||||
...(schema ? { schema } : {}),
|
||||
};
|
||||
};
|
||||
|
||||
export const langchainToolsToInferenceTools = (langchainTool: LangChainStructuredToolInterface[]) : Record<string, InferenceToolDefinition> => {
|
||||
return langchainTool.reduce((acc, tool) => {
|
||||
acc[tool.name] = langchainToolToInferenceTool(tool);
|
||||
return acc;
|
||||
}, {} as Record<string, InferenceToolDefinition>);
|
||||
}
|
||||
export const langchainToolsToInferenceTools = (
|
||||
langchainTool: LangChainStructuredToolInterface[]
|
||||
): Record<string, InferenceToolDefinition> => {
|
||||
return langchainTool.reduce((acc, tool) => {
|
||||
acc[tool.name] = langchainToolToInferenceTool(tool);
|
||||
return acc;
|
||||
}, {} as Record<string, InferenceToolDefinition>);
|
||||
};
|
||||
|
||||
function zodSchemaToInference(schema: ZodSchema): InferenceToolSchema {
|
||||
return pick(zodToJsonSchema(schema), ['type', 'properties', 'required']) as InferenceToolSchema;
|
||||
}
|
||||
return pick(zodToJsonSchema(schema), ['type', 'properties', 'required']) as InferenceToolSchema;
|
||||
}
|
||||
|
|
|
@ -1,26 +1,36 @@
|
|||
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ChatCompletionMessageEvent } from '@kbn/inference-common';
|
||||
import { AIMessage, BaseMessage as LangChainBaseMessage } from '@langchain/core/messages';
|
||||
import { isEmpty } from "lodash";
|
||||
import { isEmpty } from 'lodash';
|
||||
|
||||
|
||||
export const nlToEsqlTaskEventToLangchainMessage = (taskEvent: ChatCompletionMessageEvent): LangChainBaseMessage => {
|
||||
switch (taskEvent.type) {
|
||||
case "chatCompletionMessage": {
|
||||
const toolCalls = taskEvent.toolCalls.map((toolCall) => {
|
||||
return {
|
||||
name: toolCall.function.name,
|
||||
args: toolCall.function.arguments,
|
||||
id: toolCall.toolCallId,
|
||||
type: "tool_call"
|
||||
} as const
|
||||
})
|
||||
return new AIMessage({
|
||||
content: taskEvent.content,
|
||||
...(!isEmpty(toolCalls) ? { tool_calls: toolCalls } : {})
|
||||
})
|
||||
}
|
||||
default: {
|
||||
throw new Error(`Unable to convert nlToEsqlTaskEvent of type ${taskEvent.type} to LangChain message`)
|
||||
}
|
||||
export const nlToEsqlTaskEventToLangchainMessage = (
|
||||
taskEvent: ChatCompletionMessageEvent
|
||||
): LangChainBaseMessage => {
|
||||
switch (taskEvent.type) {
|
||||
case 'chatCompletionMessage': {
|
||||
const toolCalls = taskEvent.toolCalls.map((toolCall) => {
|
||||
return {
|
||||
name: toolCall.function.name,
|
||||
args: toolCall.function.arguments,
|
||||
id: toolCall.toolCallId,
|
||||
type: 'tool_call',
|
||||
} as const;
|
||||
});
|
||||
return new AIMessage({
|
||||
content: taskEvent.content,
|
||||
...(!isEmpty(toolCalls) ? { tool_calls: toolCalls } : {}),
|
||||
});
|
||||
}
|
||||
}
|
||||
default: {
|
||||
throw new Error(
|
||||
`Unable to convert nlToEsqlTaskEvent of type ${taskEvent.type} to LangChain message`
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -45,11 +45,9 @@ export { parseBedrockBuffer, handleBedrockChunk } from './impl/utils/bedrock';
|
|||
export { langchainMessageToInferenceMessage } from './impl/utils/langchain_message_to_inference_message';
|
||||
export {
|
||||
langchainToolToInferenceTool,
|
||||
langchainToolsToInferenceTools
|
||||
langchainToolsToInferenceTools,
|
||||
} from './impl/utils/langchain_tool_to_inference_tool';
|
||||
export {
|
||||
nlToEsqlTaskEventToLangchainMessage
|
||||
} from './impl/utils/nl_to_esql_task_event_to_langchain_message';
|
||||
export { nlToEsqlTaskEventToLangchainMessage } from './impl/utils/nl_to_esql_task_event_to_langchain_message';
|
||||
export * from './constants';
|
||||
|
||||
/** currently the same shape as "fields" property in the ES response */
|
||||
|
|
|
@ -1,61 +1,68 @@
|
|||
import { AssistantMessage, MessageRole, UserMessage, ToolMessage } from "@kbn/inference-common"
|
||||
import { ensureMultiTurn } from "./ensure_multi_turn"
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { AssistantMessage, MessageRole, UserMessage, ToolMessage } from '@kbn/inference-common';
|
||||
import { ensureMultiTurn } from './ensure_multi_turn';
|
||||
|
||||
const assistantMessage: AssistantMessage = {
|
||||
role: MessageRole.Assistant,
|
||||
content: "Hello from assistant"
|
||||
}
|
||||
role: MessageRole.Assistant,
|
||||
content: 'Hello from assistant',
|
||||
};
|
||||
|
||||
const userMessage: UserMessage = {
|
||||
role: MessageRole.User,
|
||||
content: "Hello from user"
|
||||
}
|
||||
role: MessageRole.User,
|
||||
content: 'Hello from user',
|
||||
};
|
||||
|
||||
const toolMessage: ToolMessage = {
|
||||
role: MessageRole.Tool,
|
||||
response: "Hello from tool",
|
||||
toolCallId: "123",
|
||||
name: "toolName"
|
||||
}
|
||||
role: MessageRole.Tool,
|
||||
response: 'Hello from tool',
|
||||
toolCallId: '123',
|
||||
name: 'toolName',
|
||||
};
|
||||
|
||||
const intermediaryUserMessage: UserMessage = {
|
||||
role: MessageRole.User,
|
||||
content: "-"
|
||||
}
|
||||
role: MessageRole.User,
|
||||
content: '-',
|
||||
};
|
||||
|
||||
const intemediaryAssistantMessage: AssistantMessage = {
|
||||
role: MessageRole.Assistant,
|
||||
content: "-"
|
||||
}
|
||||
role: MessageRole.Assistant,
|
||||
content: '-',
|
||||
};
|
||||
|
||||
describe("ensureMultiTurn", () => {
|
||||
it("returns correct value for message sequence", () => {
|
||||
const messages = [
|
||||
assistantMessage,
|
||||
assistantMessage,
|
||||
userMessage,
|
||||
userMessage,
|
||||
toolMessage,
|
||||
toolMessage,
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage]
|
||||
describe('ensureMultiTurn', () => {
|
||||
it('returns correct value for message sequence', () => {
|
||||
const messages = [
|
||||
assistantMessage,
|
||||
assistantMessage,
|
||||
userMessage,
|
||||
userMessage,
|
||||
toolMessage,
|
||||
toolMessage,
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage,
|
||||
];
|
||||
|
||||
const result = ensureMultiTurn(messages)
|
||||
const result = ensureMultiTurn(messages);
|
||||
|
||||
expect(result).toEqual([
|
||||
assistantMessage,
|
||||
intermediaryUserMessage,
|
||||
assistantMessage,
|
||||
userMessage,
|
||||
intemediaryAssistantMessage,
|
||||
userMessage,
|
||||
toolMessage,
|
||||
toolMessage,
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage
|
||||
])
|
||||
|
||||
})
|
||||
})
|
||||
expect(result).toEqual([
|
||||
assistantMessage,
|
||||
intermediaryUserMessage,
|
||||
assistantMessage,
|
||||
userMessage,
|
||||
intemediaryAssistantMessage,
|
||||
userMessage,
|
||||
toolMessage,
|
||||
toolMessage,
|
||||
userMessage,
|
||||
assistantMessage,
|
||||
toolMessage,
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -7,54 +7,58 @@
|
|||
|
||||
import { Message, MessageRole } from '@kbn/inference-common';
|
||||
|
||||
type MessageRoleSequenceResult = {
|
||||
roleSequenceValid: true;
|
||||
} | {
|
||||
roleSequenceValid: false;
|
||||
intermediaryRole: MessageRole.User | MessageRole.Assistant;
|
||||
};
|
||||
type MessageRoleSequenceResult =
|
||||
| {
|
||||
roleSequenceValid: true;
|
||||
}
|
||||
| {
|
||||
roleSequenceValid: false;
|
||||
intermediaryRole: MessageRole.User | MessageRole.Assistant;
|
||||
};
|
||||
|
||||
/**
|
||||
* Two consecutive messages with USER role or two consecutive messages with ASSISTANT role are not allowed.
|
||||
* Consecutive messages with TOOL role are allowed.
|
||||
*/
|
||||
function checkMessageRoleSequenceValid(prevMessage: Message | undefined, message: Message): MessageRoleSequenceResult {
|
||||
function checkMessageRoleSequenceValid(
|
||||
prevMessage: Message | undefined,
|
||||
message: Message
|
||||
): MessageRoleSequenceResult {
|
||||
if (!prevMessage) {
|
||||
return {
|
||||
roleSequenceValid: true
|
||||
roleSequenceValid: true,
|
||||
};
|
||||
}
|
||||
|
||||
if (prevMessage.role === MessageRole.User && message.role === MessageRole.User) {
|
||||
return {
|
||||
roleSequenceValid: false,
|
||||
intermediaryRole: MessageRole.Assistant
|
||||
intermediaryRole: MessageRole.Assistant,
|
||||
};
|
||||
}
|
||||
if (prevMessage.role === MessageRole.Assistant && message.role === MessageRole.Assistant) {
|
||||
return {
|
||||
roleSequenceValid: false,
|
||||
intermediaryRole: MessageRole.User
|
||||
intermediaryRole: MessageRole.User,
|
||||
};
|
||||
}
|
||||
if (prevMessage.role === MessageRole.Tool && message.role === MessageRole.Tool) {
|
||||
return {
|
||||
roleSequenceValid: true
|
||||
roleSequenceValid: true,
|
||||
};
|
||||
}
|
||||
if (prevMessage.role === message.role) {
|
||||
return {
|
||||
roleSequenceValid: false,
|
||||
intermediaryRole: MessageRole.Assistant
|
||||
}
|
||||
intermediaryRole: MessageRole.Assistant,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
roleSequenceValid: true
|
||||
}
|
||||
roleSequenceValid: true,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
export function ensureMultiTurn(messages: Message[]): Message[] {
|
||||
const next: Message[] = [];
|
||||
|
||||
|
|
|
@ -1,10 +1,17 @@
|
|||
import { getEsqlFromContent } from "./common"
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
describe("common", () => {
|
||||
it.each([
|
||||
["```esqlhelloworld```", ['helloworld']],
|
||||
["```esqlhelloworld``````esqlhelloworld```", ['helloworld','helloworld']]
|
||||
])('should add %s and %s', (input: string, expectedResult: string[]) => {
|
||||
expect(getEsqlFromContent(input)).toEqual(expectedResult)
|
||||
})
|
||||
})
|
||||
import { getEsqlFromContent } from './common';
|
||||
|
||||
describe('common', () => {
|
||||
it.each([
|
||||
['```esqlhelloworld```', ['helloworld']],
|
||||
['```esqlhelloworld``````esqlhelloworld```', ['helloworld', 'helloworld']],
|
||||
])('should add %s and %s', (input: string, expectedResult: string[]) => {
|
||||
expect(getEsqlFromContent(input)).toEqual(expectedResult);
|
||||
});
|
||||
});
|
||||
|
|
|
@ -14,20 +14,20 @@ export const getPromptSuffixForOssModel = (toolName: string) => `
|
|||
|
||||
It is important that ES|QL query is preceeded by a new line.`;
|
||||
|
||||
export const getEsqlFromContent = (content: string) : string[] => {
|
||||
const extractedEsql = []
|
||||
let index = 0
|
||||
export const getEsqlFromContent = (content: string): string[] => {
|
||||
const extractedEsql = [];
|
||||
let index = 0;
|
||||
while (index < content.length) {
|
||||
const start = content.indexOf('```esql', index)
|
||||
const start = content.indexOf('```esql', index);
|
||||
if (start === -1) {
|
||||
break
|
||||
break;
|
||||
}
|
||||
const end = content.indexOf('```', start + 7)
|
||||
const end = content.indexOf('```', start + 7);
|
||||
if (end === -1) {
|
||||
break
|
||||
break;
|
||||
}
|
||||
extractedEsql.push(content.slice(start + 7, end))
|
||||
index = end + 3
|
||||
extractedEsql.push(content.slice(start + 7, end));
|
||||
index = end + 3;
|
||||
}
|
||||
return extractedEsql
|
||||
}
|
||||
return extractedEsql;
|
||||
};
|
||||
|
|
|
@ -1,3 +1,10 @@
|
|||
export const NL_TO_ESQL_AGENT_NODE = "nl_to_esql_agent";
|
||||
export const TOOLS_NODE = "tools";
|
||||
export const ESQL_VALIDATOR_NODE = "esql_validator";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const NL_TO_ESQL_AGENT_NODE = 'nl_to_esql_agent';
|
||||
export const TOOLS_NODE = 'tools';
|
||||
export const ESQL_VALIDATOR_NODE = 'esql_validator';
|
||||
|
|
|
@ -1,68 +1,71 @@
|
|||
import { StructuredToolInterface } from "@langchain/core/tools";
|
||||
import { getIndexNamesTool } from "./index_names_tool";
|
||||
import { ElasticsearchClient, KibanaRequest } from "@kbn/core/server";
|
||||
import { ToolNode } from "@langchain/langgraph/prebuilt";
|
||||
import { END, START, StateGraph } from "@langchain/langgraph";
|
||||
import { EsqlSelfHealingAnnotation } from "./state";
|
||||
import { ESQL_VALIDATOR_NODE, NL_TO_ESQL_AGENT_NODE, TOOLS_NODE } from "./constants";
|
||||
import { stepRouter } from "./step_router";
|
||||
import { getNlToEsqlAgent } from "./nl_to_esql_agent";
|
||||
import { InferenceServerStart } from "@kbn/inference-plugin/server";
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { getValidatorNode } from "./validator";
|
||||
import { getInspectIndexMappingTool } from "./inspect_index_mapping_tool";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { StructuredToolInterface } from '@langchain/core/tools';
|
||||
import type { ElasticsearchClient, KibanaRequest } from '@kbn/core/server';
|
||||
import { ToolNode } from '@langchain/langgraph/prebuilt';
|
||||
import { END, START, StateGraph } from '@langchain/langgraph';
|
||||
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
import { getIndexNamesTool } from './index_names_tool';
|
||||
import { EsqlSelfHealingAnnotation } from './state';
|
||||
import { ESQL_VALIDATOR_NODE, NL_TO_ESQL_AGENT_NODE, TOOLS_NODE } from './constants';
|
||||
import { stepRouter } from './step_router';
|
||||
import { getNlToEsqlAgent } from './nl_to_esql_agent';
|
||||
import { getValidatorNode } from './validator';
|
||||
import { getInspectIndexMappingTool } from './inspect_index_mapping_tool';
|
||||
|
||||
export const getEsqlSelfHealingGraph = ({
|
||||
esClient,
|
||||
connectorId,
|
||||
inference,
|
||||
logger,
|
||||
request,
|
||||
}: {
|
||||
esClient: ElasticsearchClient;
|
||||
connectorId: string;
|
||||
inference: InferenceServerStart;
|
||||
logger: Logger;
|
||||
request: KibanaRequest;
|
||||
}) => {
|
||||
const availableIndexNamesTool = getIndexNamesTool({
|
||||
esClient,
|
||||
});
|
||||
const inspectIndexMappingTool = getInspectIndexMappingTool({
|
||||
esClient,
|
||||
});
|
||||
|
||||
const tools: StructuredToolInterface[] = [availableIndexNamesTool, inspectIndexMappingTool];
|
||||
|
||||
const toolNode = new ToolNode(tools);
|
||||
const nlToEsqlAgentNode = getNlToEsqlAgent({
|
||||
connectorId,
|
||||
inference,
|
||||
logger,
|
||||
request,
|
||||
}: {
|
||||
esClient: ElasticsearchClient
|
||||
connectorId: string,
|
||||
inference: InferenceServerStart,
|
||||
logger: Logger,
|
||||
request: KibanaRequest,
|
||||
}
|
||||
) => {
|
||||
tools,
|
||||
});
|
||||
const validatorNode = getValidatorNode({
|
||||
esClient,
|
||||
});
|
||||
|
||||
const availableIndexNamesTool = getIndexNamesTool({
|
||||
esClient
|
||||
const graph = new StateGraph(EsqlSelfHealingAnnotation)
|
||||
.addNode(NL_TO_ESQL_AGENT_NODE, nlToEsqlAgentNode)
|
||||
.addNode(TOOLS_NODE, toolNode)
|
||||
.addNode(ESQL_VALIDATOR_NODE, validatorNode, {
|
||||
ends: [END, NL_TO_ESQL_AGENT_NODE],
|
||||
})
|
||||
const inspectIndexMappingTool = getInspectIndexMappingTool({
|
||||
esClient
|
||||
.addEdge(START, NL_TO_ESQL_AGENT_NODE)
|
||||
.addEdge(TOOLS_NODE, NL_TO_ESQL_AGENT_NODE)
|
||||
.addConditionalEdges(NL_TO_ESQL_AGENT_NODE, stepRouter, {
|
||||
[TOOLS_NODE]: TOOLS_NODE,
|
||||
[ESQL_VALIDATOR_NODE]: ESQL_VALIDATOR_NODE,
|
||||
})
|
||||
.compile();
|
||||
|
||||
const tools: StructuredToolInterface[] = [
|
||||
availableIndexNamesTool,
|
||||
inspectIndexMappingTool
|
||||
]
|
||||
|
||||
const toolNode = new ToolNode(tools)
|
||||
const nlToEsqlAgentNode = getNlToEsqlAgent({
|
||||
connectorId,
|
||||
inference,
|
||||
logger,
|
||||
request,
|
||||
tools,
|
||||
})
|
||||
const validatorNode = getValidatorNode({
|
||||
esClient
|
||||
})
|
||||
|
||||
const graph = new StateGraph(EsqlSelfHealingAnnotation)
|
||||
.addNode(NL_TO_ESQL_AGENT_NODE, nlToEsqlAgentNode)
|
||||
.addNode(TOOLS_NODE, toolNode)
|
||||
.addNode(ESQL_VALIDATOR_NODE, validatorNode, {
|
||||
ends: [END, NL_TO_ESQL_AGENT_NODE]
|
||||
})
|
||||
.addEdge(START, NL_TO_ESQL_AGENT_NODE)
|
||||
.addEdge(TOOLS_NODE, NL_TO_ESQL_AGENT_NODE)
|
||||
.addConditionalEdges(NL_TO_ESQL_AGENT_NODE, stepRouter, {
|
||||
[TOOLS_NODE]: TOOLS_NODE,
|
||||
[ESQL_VALIDATOR_NODE]: ESQL_VALIDATOR_NODE
|
||||
}).compile()
|
||||
|
||||
return graph
|
||||
}
|
||||
return graph;
|
||||
};
|
||||
|
|
|
@ -1,27 +1,39 @@
|
|||
import { ElasticsearchClient } from "@kbn/core/server";
|
||||
import { tool } from "@langchain/core/tools";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { tool } from '@langchain/core/tools';
|
||||
|
||||
const toolDetails = {
|
||||
name: "available_index_names",
|
||||
description: "Get the available indices in the elastic search cluster. Use this when there is an unknown index error or you need to get the indeces that can be queried. Using the response select an appropriate index name."
|
||||
}
|
||||
name: 'available_index_names',
|
||||
description:
|
||||
'Get the available indices in the elastic search cluster. Use this when there is an unknown index error or you need to get the indeces that can be queried. Using the response select an appropriate index name.',
|
||||
};
|
||||
|
||||
export const getIndexNamesTool = ({
|
||||
esClient
|
||||
}: {
|
||||
esClient: ElasticsearchClient
|
||||
}) => {
|
||||
return tool(async () => {
|
||||
const indexNames = await esClient.cat.indices({
|
||||
format: 'json'
|
||||
}).then((response) => response
|
||||
export const getIndexNamesTool = ({ esClient }: { esClient: ElasticsearchClient }) => {
|
||||
return tool(
|
||||
async () => {
|
||||
const indexNames = await esClient.cat
|
||||
.indices({
|
||||
format: 'json',
|
||||
})
|
||||
.then((response) =>
|
||||
response
|
||||
.map((index) => index.index)
|
||||
.filter((index) => index!=undefined)
|
||||
.filter((index) => index != undefined)
|
||||
.toSorted()
|
||||
)
|
||||
return `These are the names of the available indeces. To query them, you must use the full index name verbatim.\n\n${indexNames.join('\n')}`
|
||||
}, {
|
||||
name: toolDetails.name,
|
||||
description: toolDetails.description
|
||||
})
|
||||
}
|
||||
);
|
||||
return `These are the names of the available indeces. To query them, you must use the full index name verbatim.\n\n${indexNames.join(
|
||||
'\n'
|
||||
)}`;
|
||||
},
|
||||
{
|
||||
name: toolDetails.name,
|
||||
description: toolDetails.description,
|
||||
}
|
||||
);
|
||||
};
|
||||
|
|
|
@ -1,10 +1,17 @@
|
|||
import { ElasticsearchClient } from "@kbn/core/server"
|
||||
import { tool } from "@langchain/core/tools"
|
||||
import { z } from "zod"
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { tool } from '@langchain/core/tools';
|
||||
import { z } from '@kbn/zod';
|
||||
|
||||
const toolDetails = {
|
||||
name: "inspect_index_mapping",
|
||||
description: `Use this tool to inspect an index mapping. The tool with fetch the mappings of the provided indexName and then return the data at the given propertyKey. For example:
|
||||
name: 'inspect_index_mapping',
|
||||
description: `Use this tool to inspect an index mapping. The tool with fetch the mappings of the provided indexName and then return the data at the given propertyKey. For example:
|
||||
|
||||
Example index mapping:
|
||||
\`\`\`
|
||||
|
@ -59,58 +66,61 @@ Output:
|
|||
{
|
||||
"type": "keyword",
|
||||
}
|
||||
\`\`\``
|
||||
}
|
||||
\`\`\``,
|
||||
};
|
||||
|
||||
export const getInspectIndexMappingTool = ({
|
||||
esClient
|
||||
}: {
|
||||
esClient: ElasticsearchClient
|
||||
}) => {
|
||||
return tool(async ({
|
||||
indexName,
|
||||
propertyKey
|
||||
}) => {
|
||||
const indexMapping = await esClient.indices.getMapping({
|
||||
index: indexName
|
||||
})
|
||||
export const getInspectIndexMappingTool = ({ esClient }: { esClient: ElasticsearchClient }) => {
|
||||
return tool(
|
||||
async ({ indexName, propertyKey }) => {
|
||||
const indexMapping = await esClient.indices.getMapping({
|
||||
index: indexName,
|
||||
});
|
||||
|
||||
const entriesAtKey = getEntriesAtKey(indexMapping[indexName], propertyKey.split("."))
|
||||
const result = formatEntriesAtKey(entriesAtKey)
|
||||
const entriesAtKey = getEntriesAtKey(indexMapping[indexName], propertyKey.split('.'));
|
||||
const result = formatEntriesAtKey(entriesAtKey);
|
||||
|
||||
return `Object at ${propertyKey} \n${JSON.stringify(result, null, 2)}`
|
||||
}, {
|
||||
name: toolDetails.name,
|
||||
description: toolDetails.description,
|
||||
schema: z.object({
|
||||
indexName: z.string().describe(`The index name to get the properties of.`),
|
||||
propertyKey: z.string().optional().default("mappings.properties").describe(`The key to get the properties of.`)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
const getEntriesAtKey = (mapping: Record<string, any> | undefined, keys: string[]): Record<string, any> | undefined => {
|
||||
if (mapping === undefined) {
|
||||
return
|
||||
}
|
||||
if (keys.length === 0) {
|
||||
return mapping
|
||||
return `Object at ${propertyKey} \n${JSON.stringify(result, null, 2)}`;
|
||||
},
|
||||
{
|
||||
name: toolDetails.name,
|
||||
description: toolDetails.description,
|
||||
schema: z.object({
|
||||
indexName: z.string().describe(`The index name to get the properties of.`),
|
||||
propertyKey: z
|
||||
.string()
|
||||
.optional()
|
||||
.default('mappings.properties')
|
||||
.describe(`The key to get the properties of.`),
|
||||
}),
|
||||
}
|
||||
);
|
||||
};
|
||||
|
||||
const key = keys.shift()
|
||||
if (key === undefined) {
|
||||
return mapping
|
||||
}
|
||||
const getEntriesAtKey = (
|
||||
mapping: Record<string, any> | undefined,
|
||||
keys: string[]
|
||||
): Record<string, any> | undefined => {
|
||||
if (mapping === undefined) {
|
||||
return;
|
||||
}
|
||||
if (keys.length === 0) {
|
||||
return mapping;
|
||||
}
|
||||
|
||||
return getEntriesAtKey(mapping[key], keys)
|
||||
}
|
||||
const key = keys.shift();
|
||||
if (key === undefined) {
|
||||
return mapping;
|
||||
}
|
||||
|
||||
return getEntriesAtKey(mapping[key], keys);
|
||||
};
|
||||
|
||||
const formatEntriesAtKey = (mapping: Record<string, any> | undefined): Record<string, string> => {
|
||||
if (mapping === undefined) {
|
||||
return {}
|
||||
}
|
||||
return Object.entries(mapping).reduce((acc, [key, value]) => {
|
||||
acc[key] = typeof value === "string" ? value : "Object"
|
||||
return acc
|
||||
}, {} as Record<string, string>)
|
||||
}
|
||||
if (mapping === undefined) {
|
||||
return {};
|
||||
}
|
||||
return Object.entries(mapping).reduce((acc, [key, value]) => {
|
||||
acc[key] = typeof value === 'string' ? value : 'Object';
|
||||
return acc;
|
||||
}, {} as Record<string, string>);
|
||||
};
|
||||
|
|
|
@ -1,180 +1,201 @@
|
|||
import { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
|
||||
import { EditorError, parse } from '@kbn/esql-ast';
|
||||
import { InferenceServerStart, naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
|
||||
import type { EditorError } from '@kbn/esql-ast';
|
||||
import { parse } from '@kbn/esql-ast';
|
||||
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
||||
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
import { lastValueFrom } from 'rxjs';
|
||||
import { getEsqlFromContent } from './common';
|
||||
import { isEmpty, pick } from 'lodash';
|
||||
import { z, ZodSchema } from 'zod';
|
||||
import { JsonSchema7Type, zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { ToolSchema } from '@kbn/inference-common';
|
||||
import { keyword } from '@kbn/fleet-plugin/server/services/epm/elasticsearch/template/mappings';
|
||||
import type { ZodSchema } from '@kbn/zod';
|
||||
import { z } from '@kbn/zod';
|
||||
import { zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import type { ToolSchema } from '@kbn/inference-common';
|
||||
import { getEsqlFromContent } from './common';
|
||||
|
||||
interface NlToEsqlCommand {
|
||||
query: string;
|
||||
query: string;
|
||||
}
|
||||
|
||||
interface NlToEsqlCommandWithError extends NlToEsqlCommand {
|
||||
parsingErrors?: EditorError[];
|
||||
executeError?: unknown;
|
||||
parsingErrors?: EditorError[];
|
||||
executeError?: unknown;
|
||||
}
|
||||
|
||||
export class NaturalLanguageToEsqlValidator {
|
||||
private readonly inference: InferenceServerStart
|
||||
private readonly connectorId: string
|
||||
private readonly logger: Logger
|
||||
private readonly request: KibanaRequest
|
||||
private readonly esClient: ElasticsearchClient
|
||||
private readonly inference: InferenceServerStart;
|
||||
private readonly connectorId: string;
|
||||
private readonly logger: Logger;
|
||||
private readonly request: KibanaRequest;
|
||||
private readonly esClient: ElasticsearchClient;
|
||||
|
||||
constructor(
|
||||
params: {
|
||||
inference: InferenceServerStart,
|
||||
connectorId: string,
|
||||
logger: Logger
|
||||
request: KibanaRequest
|
||||
esClient: ElasticsearchClient
|
||||
}
|
||||
) {
|
||||
this.inference = params.inference;
|
||||
this.connectorId = params.connectorId;
|
||||
this.logger = params.logger;
|
||||
this.request = params.request;
|
||||
this.esClient = params.esClient;
|
||||
constructor(params: {
|
||||
inference: InferenceServerStart;
|
||||
connectorId: string;
|
||||
logger: Logger;
|
||||
request: KibanaRequest;
|
||||
esClient: ElasticsearchClient;
|
||||
}) {
|
||||
this.inference = params.inference;
|
||||
this.connectorId = params.connectorId;
|
||||
this.logger = params.logger;
|
||||
this.request = params.request;
|
||||
this.esClient = params.esClient;
|
||||
}
|
||||
|
||||
private async callNaturalLanguageToEsql(question: string) {
|
||||
return lastValueFrom(
|
||||
naturalLanguageToEsql({
|
||||
client: this.inference.getClient({ request: this.request }),
|
||||
connectorId: this.connectorId,
|
||||
input: question,
|
||||
functionCalling: 'auto',
|
||||
logger: this.logger,
|
||||
tools: {
|
||||
get_available_indecies: {
|
||||
description:
|
||||
'Get the available indecies in the elastic search cluster. Use this when there is an unknown index error.',
|
||||
schema: zodSchemaToInference(
|
||||
z.object({
|
||||
keyword: z.string(),
|
||||
})
|
||||
),
|
||||
},
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
private esqlParsingErrors(esqlQuery: string): NlToEsqlCommand | undefined {
|
||||
const { errors: parsingErrors, root } = parse(esqlQuery);
|
||||
|
||||
console.log(root);
|
||||
|
||||
if (!isEmpty(parsingErrors)) {
|
||||
return {
|
||||
query: esqlQuery,
|
||||
parsingErrors,
|
||||
} as NlToEsqlCommandWithError;
|
||||
}
|
||||
}
|
||||
|
||||
private async testRunQuery(esqlQuery: string) {
|
||||
try {
|
||||
await this.esClient.esql.query({
|
||||
query: esqlQuery,
|
||||
format: 'json',
|
||||
});
|
||||
} catch (e) {
|
||||
return {
|
||||
query: esqlQuery,
|
||||
executeError: e,
|
||||
} as NlToEsqlCommandWithError;
|
||||
}
|
||||
}
|
||||
|
||||
private async getAvailableIndeciesPrompt() {
|
||||
return this.esClient.cat
|
||||
.indices({
|
||||
format: 'json',
|
||||
})
|
||||
.then((response) => {
|
||||
return `The available indecies are\n${response
|
||||
.map((index: any) => index.index)
|
||||
.join('\n')}`;
|
||||
});
|
||||
}
|
||||
|
||||
private async recursivlyGenerateAndValidateEsql(
|
||||
question: string,
|
||||
depth = 0
|
||||
): Promise<Array<string | undefined>> {
|
||||
if (depth > 3) {
|
||||
return [question];
|
||||
}
|
||||
const generateEvent = await this.callNaturalLanguageToEsql(question);
|
||||
console.log('generateEvent');
|
||||
console.log(JSON.stringify(generateEvent));
|
||||
if (!generateEvent.content) {
|
||||
return [`Unable to generate query.\n${question}`];
|
||||
}
|
||||
const queries = getEsqlFromContent(generateEvent.content);
|
||||
if (isEmpty(queries)) {
|
||||
return [generateEvent.content];
|
||||
}
|
||||
|
||||
private async callNaturalLanguageToEsql(question: string) {
|
||||
return lastValueFrom(
|
||||
naturalLanguageToEsql({
|
||||
client: this.inference.getClient({ request: this.request }),
|
||||
connectorId: this.connectorId,
|
||||
input: question,
|
||||
functionCalling: 'auto',
|
||||
logger: this.logger,
|
||||
tools:{
|
||||
"get_available_indecies": {
|
||||
"description": "Get the available indecies in the elastic search cluster. Use this when there is an unknown index error.",
|
||||
"schema": zodSchemaToInference(z.object({
|
||||
keyword: z.string()
|
||||
}))
|
||||
}
|
||||
}
|
||||
})
|
||||
);
|
||||
const results = await Promise.all(
|
||||
queries.map(async (query) => {
|
||||
if (isEmpty(query)) return undefined;
|
||||
|
||||
let errors = this.esqlParsingErrors(query);
|
||||
|
||||
if (this.isNlToEsqlCommandWithError(errors)) {
|
||||
return this.recursivlyGenerateAndValidateEsql(
|
||||
await this.formatEsqlQueryErrorForPrompt(errors),
|
||||
depth + 1
|
||||
);
|
||||
}
|
||||
|
||||
errors = await this.testRunQuery(query);
|
||||
|
||||
if (this.isNlToEsqlCommandWithError(errors)) {
|
||||
return this.recursivlyGenerateAndValidateEsql(
|
||||
await this.formatEsqlQueryErrorForPrompt(errors),
|
||||
depth + 1
|
||||
);
|
||||
}
|
||||
|
||||
return query;
|
||||
})
|
||||
);
|
||||
|
||||
return results.flat().filter((result) => result !== undefined);
|
||||
}
|
||||
|
||||
public async generateEsqlFromNaturalLanguage(question: string) {
|
||||
return this.recursivlyGenerateAndValidateEsql(question);
|
||||
}
|
||||
|
||||
private isNlToEsqlCommandWithError(
|
||||
command: undefined | NlToEsqlCommand | NlToEsqlCommandWithError
|
||||
): command is NlToEsqlCommandWithError {
|
||||
return (
|
||||
(command as undefined | NlToEsqlCommandWithError)?.parsingErrors !== undefined ||
|
||||
(command as undefined | NlToEsqlCommandWithError)?.executeError !== undefined
|
||||
);
|
||||
}
|
||||
|
||||
private async formatEsqlQueryErrorForPrompt(error: NlToEsqlCommand): Promise<string> {
|
||||
if (!this.isNlToEsqlCommandWithError(error)) {
|
||||
throw new Error('Error is not an NlToEsqlCommandWithError');
|
||||
}
|
||||
|
||||
private esqlParsingErrors(esqlQuery: string): NlToEsqlCommand | undefined {
|
||||
const { errors: parsingErrors, root } = parse(esqlQuery)
|
||||
|
||||
console.log(root)
|
||||
|
||||
if (!isEmpty(parsingErrors)) {
|
||||
return {
|
||||
query: esqlQuery,
|
||||
parsingErrors
|
||||
} as NlToEsqlCommandWithError
|
||||
}
|
||||
return
|
||||
let errorString = `The query bellow could not be executed due to the following errors. Try again or reply with debugging instructions\n\`\`\`esql${error.query}\`\`\`\n`;
|
||||
if (error.parsingErrors) {
|
||||
errorString += 'Parsing Errors:\n';
|
||||
error.parsingErrors.forEach((parsingError) => {
|
||||
errorString += `${parsingError.message}\n`;
|
||||
});
|
||||
}
|
||||
|
||||
private async testRunQuery(esqlQuery: string) {
|
||||
try {
|
||||
await this.esClient.esql.query(
|
||||
{
|
||||
query: esqlQuery,
|
||||
format: 'json',
|
||||
},
|
||||
)
|
||||
} catch (e) {
|
||||
return {
|
||||
query: esqlQuery,
|
||||
executeError: e
|
||||
} as NlToEsqlCommandWithError
|
||||
}
|
||||
return
|
||||
if (error.executeError) {
|
||||
errorString += `Execution Errors:\n${(error.executeError as any).message}\n`;
|
||||
}
|
||||
|
||||
private async getAvailableIndeciesPrompt() {
|
||||
return this.esClient.cat.indices({
|
||||
format:'json'
|
||||
}).then((response) => {
|
||||
return `The available indecies are\n${response.map((index: any) => index.index).join('\n')}`
|
||||
})
|
||||
}
|
||||
|
||||
private async recursivlyGenerateAndValidateEsql(question: string, depth = 0): Promise<(string | undefined)[]> {
|
||||
if (depth > 3) {
|
||||
return [question];
|
||||
}
|
||||
const generateEvent = await this.callNaturalLanguageToEsql(question);
|
||||
console.log("generateEvent")
|
||||
console.log(JSON.stringify(generateEvent))
|
||||
if(!generateEvent.content){
|
||||
return [`Unable to generate query.\n${question}`];
|
||||
}
|
||||
const queries = getEsqlFromContent(generateEvent.content);
|
||||
if(isEmpty(queries)){
|
||||
return [generateEvent.content];
|
||||
}
|
||||
|
||||
const results = await Promise.all(
|
||||
queries.map(async (query) => {
|
||||
if (isEmpty(query)) return undefined;
|
||||
|
||||
let errors = this.esqlParsingErrors(query);
|
||||
|
||||
if (this.isNlToEsqlCommandWithError(errors)) {
|
||||
return this.recursivlyGenerateAndValidateEsql(await this.formatEsqlQueryErrorForPrompt(errors), depth + 1);
|
||||
}
|
||||
|
||||
errors = await this.testRunQuery(query);
|
||||
|
||||
if (this.isNlToEsqlCommandWithError(errors)) {
|
||||
return this.recursivlyGenerateAndValidateEsql(await this.formatEsqlQueryErrorForPrompt(errors), depth + 1);
|
||||
}
|
||||
|
||||
return query;
|
||||
})
|
||||
)
|
||||
|
||||
return results.flat().filter((result) => result !== undefined);
|
||||
}
|
||||
|
||||
public async generateEsqlFromNaturalLanguage(question: string) {
|
||||
return this.recursivlyGenerateAndValidateEsql(question)
|
||||
}
|
||||
|
||||
private isNlToEsqlCommandWithError(command: undefined | NlToEsqlCommand | NlToEsqlCommandWithError): command is NlToEsqlCommandWithError {
|
||||
return (command as undefined | NlToEsqlCommandWithError)?.parsingErrors !== undefined || (command as undefined | NlToEsqlCommandWithError)?.executeError !== undefined;
|
||||
}
|
||||
|
||||
private async formatEsqlQueryErrorForPrompt(error: NlToEsqlCommand): Promise<string> {
|
||||
if (!this.isNlToEsqlCommandWithError(error)) {
|
||||
throw new Error('Error is not an NlToEsqlCommandWithError');
|
||||
}
|
||||
|
||||
let errorString = `The query bellow could not be executed due to the following errors. Try again or reply with debugging instructions\n\`\`\`esql${error.query}\`\`\`\n`;
|
||||
if (error.parsingErrors) {
|
||||
errorString += 'Parsing Errors:\n';
|
||||
error.parsingErrors.forEach((parsingError) => {
|
||||
errorString += `${parsingError.message}\n`;
|
||||
});
|
||||
}
|
||||
|
||||
if (error.executeError) {
|
||||
errorString += `Execution Errors:\n${(error.executeError as any).message}\n`;
|
||||
}
|
||||
|
||||
if(false && errorString.includes('Unknown index')) {
|
||||
errorString += await this.getAvailableIndeciesPrompt();
|
||||
}
|
||||
|
||||
console.log(errorString);
|
||||
this.logger.error(errorString);
|
||||
|
||||
return errorString;
|
||||
if (false && errorString.includes('Unknown index')) {
|
||||
errorString += await this.getAvailableIndeciesPrompt();
|
||||
}
|
||||
|
||||
console.log(errorString);
|
||||
this.logger.error(errorString);
|
||||
|
||||
return errorString;
|
||||
}
|
||||
}
|
||||
|
||||
function zodSchemaToInference(schema: ZodSchema): ToolSchema {
|
||||
|
|
|
@ -1,39 +1,53 @@
|
|||
import { lastValueFrom } from "rxjs";
|
||||
import { EsqlSelfHealingAnnotation } from "./state"
|
||||
import { langchainMessageToInferenceMessage, langchainToolsToInferenceTools, nlToEsqlTaskEventToLangchainMessage } from "@kbn/elastic-assistant-common";
|
||||
import { StructuredToolInterface } from "@langchain/core/tools";
|
||||
import { KibanaRequest } from "@kbn/core/server";
|
||||
import { InferenceServerStart, naturalLanguageToEsql } from "@kbn/inference-plugin/server";
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { ChatCompletionMessageEvent } from "@kbn/inference-common";
|
||||
import { Command } from "@langchain/langgraph";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { lastValueFrom } from 'rxjs';
|
||||
import {
|
||||
langchainMessageToInferenceMessage,
|
||||
langchainToolsToInferenceTools,
|
||||
nlToEsqlTaskEventToLangchainMessage,
|
||||
} from '@kbn/elastic-assistant-common';
|
||||
import type { StructuredToolInterface } from '@langchain/core/tools';
|
||||
import type { KibanaRequest } from '@kbn/core/server';
|
||||
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
||||
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
import type { ChatCompletionMessageEvent } from '@kbn/inference-common';
|
||||
import { Command } from '@langchain/langgraph';
|
||||
import type { EsqlSelfHealingAnnotation } from './state';
|
||||
|
||||
export const getNlToEsqlAgent = ({
|
||||
connectorId,
|
||||
inference,
|
||||
logger,
|
||||
request,
|
||||
tools,
|
||||
}:{
|
||||
connectorId: string,
|
||||
inference: InferenceServerStart,
|
||||
logger: Logger,
|
||||
request: KibanaRequest,
|
||||
tools: StructuredToolInterface[]
|
||||
}) => {
|
||||
return async (state: typeof EsqlSelfHealingAnnotation.State) => {
|
||||
const { messages: stateMessages } = state;
|
||||
connectorId,
|
||||
inference,
|
||||
logger,
|
||||
request,
|
||||
tools,
|
||||
}: {
|
||||
connectorId: string;
|
||||
inference: InferenceServerStart;
|
||||
logger: Logger;
|
||||
request: KibanaRequest;
|
||||
tools: StructuredToolInterface[];
|
||||
}) => {
|
||||
return async (state: typeof EsqlSelfHealingAnnotation.State) => {
|
||||
const { messages: stateMessages } = state;
|
||||
|
||||
const inferenceMessages = stateMessages.map(langchainMessageToInferenceMessage)
|
||||
const inferenceMessages = stateMessages.map(langchainMessageToInferenceMessage);
|
||||
|
||||
const result = await lastValueFrom(naturalLanguageToEsql({
|
||||
client: inference.getClient({ request: request }),
|
||||
connectorId: connectorId,
|
||||
functionCalling: 'auto',
|
||||
logger: logger,
|
||||
tools: langchainToolsToInferenceTools(tools),
|
||||
messages: inferenceMessages
|
||||
})) as ChatCompletionMessageEvent
|
||||
const result = (await lastValueFrom(
|
||||
naturalLanguageToEsql({
|
||||
client: inference.getClient({ request }),
|
||||
connectorId,
|
||||
functionCalling: 'auto',
|
||||
logger,
|
||||
tools: langchainToolsToInferenceTools(tools),
|
||||
messages: inferenceMessages,
|
||||
})
|
||||
)) as ChatCompletionMessageEvent;
|
||||
|
||||
return new Command({
|
||||
update:{
|
||||
|
@ -42,4 +56,4 @@ export const getNlToEsqlAgent = ({
|
|||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,9 +8,9 @@
|
|||
import { tool } from '@langchain/core/tools';
|
||||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import { z } from '@kbn/zod';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
import { getPromptSuffixForOssModel } from './common';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { getEsqlSelfHealingGraph } from './graph';
|
||||
|
||||
// select only some properties of AssistantToolParams
|
||||
|
@ -42,7 +42,8 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
getTool(params: ESQLToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, inference, logger, request, isOssModel, esClient } = params as ESQLToolParams;
|
||||
const { connectorId, inference, logger, request, isOssModel, esClient } =
|
||||
params as ESQLToolParams;
|
||||
if (inference == null || connectorId == null) return null;
|
||||
|
||||
const selfHealingGraph = getEsqlSelfHealingGraph({
|
||||
|
@ -51,7 +52,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
inference,
|
||||
logger,
|
||||
request,
|
||||
})
|
||||
});
|
||||
|
||||
return tool(
|
||||
async ({ question }) => {
|
||||
|
@ -74,5 +75,3 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
);
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -14,4 +14,4 @@ export const EsqlSelfHealingAnnotation = Annotation.Root({
|
|||
reducer: (currentValue, newValue) => newValue ?? currentValue,
|
||||
default: () => 30,
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
|
|
@ -1,16 +1,23 @@
|
|||
import { EsqlSelfHealingAnnotation } from "./state";
|
||||
import { ESQL_VALIDATOR_NODE, TOOLS_NODE } from "./constants";
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { EsqlSelfHealingAnnotation } from './state';
|
||||
import { ESQL_VALIDATOR_NODE, TOOLS_NODE } from './constants';
|
||||
|
||||
export const stepRouter = (state: typeof EsqlSelfHealingAnnotation.State): string => {
|
||||
const { messages } = state;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (
|
||||
'tool_calls' in lastMessage &&
|
||||
Array.isArray(lastMessage.tool_calls) &&
|
||||
lastMessage.tool_calls?.length
|
||||
) {
|
||||
return TOOLS_NODE;
|
||||
}
|
||||
const { messages } = state;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
if (
|
||||
'tool_calls' in lastMessage &&
|
||||
Array.isArray(lastMessage.tool_calls) &&
|
||||
lastMessage.tool_calls?.length
|
||||
) {
|
||||
return TOOLS_NODE;
|
||||
}
|
||||
|
||||
return ESQL_VALIDATOR_NODE
|
||||
}
|
||||
return ESQL_VALIDATOR_NODE;
|
||||
};
|
||||
|
|
|
@ -1,29 +1,34 @@
|
|||
import { ElasticsearchClient } from "@kbn/core/server"
|
||||
import { EsqlSelfHealingAnnotation } from "./state"
|
||||
import { getEsqlFromContent } from "./common"
|
||||
import { Command, END } from "@langchain/langgraph"
|
||||
import { EditorError, parse } from "@kbn/esql-ast"
|
||||
import { isEmpty } from "lodash"
|
||||
import { BaseMessage, HumanMessage } from "@langchain/core/messages"
|
||||
import { NL_TO_ESQL_AGENT_NODE } from "./constants"
|
||||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
type ValidateEsqlResult = {
|
||||
isValid: boolean
|
||||
query: string
|
||||
parsingErrors?: EditorError[]
|
||||
executionError?: any
|
||||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { Command, END } from '@langchain/langgraph';
|
||||
import type { EditorError } from '@kbn/esql-ast';
|
||||
import { parse } from '@kbn/esql-ast';
|
||||
import { isEmpty } from 'lodash';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import { getEsqlFromContent } from './common';
|
||||
import type { EsqlSelfHealingAnnotation } from './state';
|
||||
import { NL_TO_ESQL_AGENT_NODE } from './constants';
|
||||
|
||||
interface ValidateEsqlResult {
|
||||
isValid: boolean;
|
||||
query: string;
|
||||
parsingErrors?: EditorError[];
|
||||
executionError?: any;
|
||||
}
|
||||
|
||||
export const getValidatorNode = ({
|
||||
esClient
|
||||
}: {
|
||||
esClient: ElasticsearchClient
|
||||
}) => {
|
||||
return async (state: typeof EsqlSelfHealingAnnotation.State) => {
|
||||
const { messages } = state;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
export const getValidatorNode = ({ esClient }: { esClient: ElasticsearchClient }) => {
|
||||
return async (state: typeof EsqlSelfHealingAnnotation.State) => {
|
||||
const { messages } = state;
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
|
||||
const generatedQueries = getEsqlFromContent(lastMessage.content as string)
|
||||
const generatedQueries = getEsqlFromContent(lastMessage.content as string);
|
||||
|
||||
if (!generatedQueries.length) {
|
||||
return new Command({
|
||||
|
@ -58,7 +63,6 @@ export const getValidatorNode = ({
|
|||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const lastMessageWithErrorReport = (message: string, validateEsqlResults: ValidateEsqlResult[]): BaseMessage => {
|
||||
let messageWithErrorReport = message
|
||||
|
@ -68,10 +72,49 @@ const lastMessageWithErrorReport = (message: string, validateEsqlResults: Valida
|
|||
messageWithErrorReport = `${messageWithErrorReport.slice(0, index + 3)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`
|
||||
})
|
||||
|
||||
return new HumanMessage({
|
||||
content: messageWithErrorReport
|
||||
})
|
||||
}
|
||||
const containsInvalidQueries = validateEsqlResults.some((result) => !result.isValid);
|
||||
|
||||
if (containsInvalidQueries) {
|
||||
return new Command({
|
||||
update: {
|
||||
messages: [
|
||||
lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults),
|
||||
],
|
||||
},
|
||||
goto: NL_TO_ESQL_AGENT_NODE,
|
||||
});
|
||||
}
|
||||
|
||||
return new Command({
|
||||
goto: END,
|
||||
update: {
|
||||
messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to modified futher.`,
|
||||
},
|
||||
});
|
||||
};
|
||||
};
|
||||
|
||||
const lastMessageWithErrorReport = (
|
||||
message: string,
|
||||
validateEsqlResults: ValidateEsqlResult[]
|
||||
): BaseMessage => {
|
||||
let messageWithErrorReport = message;
|
||||
validateEsqlResults.reverse().forEach((validateEsqlResult) => {
|
||||
const index = messageWithErrorReport.indexOf(
|
||||
'```',
|
||||
messageWithErrorReport.indexOf(validateEsqlResult.query)
|
||||
);
|
||||
const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult);
|
||||
messageWithErrorReport = `${messageWithErrorReport.slice(
|
||||
0,
|
||||
index + 3
|
||||
)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`;
|
||||
});
|
||||
|
||||
return new HumanMessage({
|
||||
content: messageWithErrorReport,
|
||||
});
|
||||
};
|
||||
|
||||
const formatValidateEsqlResultToHumanReadable = (validateEsqlResult: ValidateEsqlResult) => {
|
||||
if (validateEsqlResult.isValid) {
|
||||
|
@ -112,12 +155,38 @@ const validateEsql = async (esClient: ElasticsearchClient, query: string): Promi
|
|||
}
|
||||
}
|
||||
|
||||
const validateEsql = async (
|
||||
esClient: ElasticsearchClient,
|
||||
query: string
|
||||
): Promise<ValidateEsqlResult> => {
|
||||
const { errors: parsingErrors } = parse(query);
|
||||
if (!isEmpty(parsingErrors)) {
|
||||
return {
|
||||
isValid: true,
|
||||
query
|
||||
}
|
||||
}
|
||||
isValid: false,
|
||||
query,
|
||||
parsingErrors,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
await esClient.esql.query({
|
||||
query,
|
||||
format: 'json',
|
||||
});
|
||||
} catch (executionError) {
|
||||
return {
|
||||
isValid: false,
|
||||
query,
|
||||
executionError,
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: true,
|
||||
query,
|
||||
};
|
||||
};
|
||||
|
||||
const extractErrorMessage = (error: any): string => {
|
||||
return error.message || `Unknown error`
|
||||
}
|
||||
return error.message || `Unknown error`;
|
||||
};
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue