mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
prevent too many ESQl generation attempts
This commit is contained in:
parent
c76a4469ae
commit
58e6e1eded
4 changed files with 31 additions and 11 deletions
|
@ -37,7 +37,8 @@ export const getNlToEsqlAgent = ({
|
|||
|
||||
return new Command({
|
||||
update:{
|
||||
messages : [nlToEsqlTaskEventToLangchainMessage(result)]
|
||||
messages : [nlToEsqlTaskEventToLangchainMessage(result)],
|
||||
maximumLLMCalls: state.maximumLLMCalls - 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -56,7 +56,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
return tool(
|
||||
async ({ question }) => {
|
||||
const humanMessage = new HumanMessage({ content: question })
|
||||
const result = await selfHealingGraph.invoke({ messages: [humanMessage] })
|
||||
const result = await selfHealingGraph.invoke({ messages: [humanMessage]})
|
||||
const { messages } = result
|
||||
const lastMessage = messages[messages.length - 1]
|
||||
return lastMessage.content
|
||||
|
|
|
@ -1,3 +1,17 @@
|
|||
import { MessagesAnnotation } from "@langchain/langgraph";
|
||||
import { BaseMessage } from "@langchain/core/messages";
|
||||
import { Annotation, messagesStateReducer } from "@langchain/langgraph";
|
||||
|
||||
export const EsqlSelfHealingAnnotation = MessagesAnnotation
|
||||
export const EsqlSelfHealingAnnotation = Annotation.Root({
|
||||
messages: Annotation<BaseMessage[]>({
|
||||
reducer: messagesStateReducer,
|
||||
default: () => [],
|
||||
}),
|
||||
maximumValidationAttempts: Annotation<number>({
|
||||
reducer: (currentValue, newValue) => newValue ?? currentValue,
|
||||
default: () => 10,
|
||||
}),
|
||||
maximumLLMCalls: Annotation<number>({
|
||||
reducer: (currentValue, newValue) => newValue ?? currentValue,
|
||||
default: () => 30,
|
||||
}),
|
||||
});
|
|
@ -27,7 +27,10 @@ export const getValidatorNode = ({
|
|||
|
||||
if (!generatedQueries.length) {
|
||||
return new Command({
|
||||
goto: END
|
||||
goto: END,
|
||||
update: {
|
||||
maximumValidationAttempts: state.maximumValidationAttempts - 1,
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -40,16 +43,18 @@ export const getValidatorNode = ({
|
|||
if (containsInvalidQueries) {
|
||||
return new Command({
|
||||
update: {
|
||||
messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)]
|
||||
messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)],
|
||||
maximumValidationAttempts: state.maximumValidationAttempts - 1
|
||||
},
|
||||
goto: NL_TO_ESQL_AGENT_NODE
|
||||
goto: state.maximumValidationAttempts <= 0 || state.maximumLLMCalls <= 0 ? END : NL_TO_ESQL_AGENT_NODE
|
||||
})
|
||||
}
|
||||
|
||||
return new Command({
|
||||
goto: END,
|
||||
update:{
|
||||
messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to modified futher.`
|
||||
update: {
|
||||
messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to modified futher.`,
|
||||
maximumValidationAttempts: state.maximumValidationAttempts - 1
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -60,7 +65,7 @@ const lastMessageWithErrorReport = (message: string, validateEsqlResults: Valida
|
|||
validateEsqlResults.reverse().forEach((validateEsqlResult) => {
|
||||
const index = messageWithErrorReport.indexOf('```', messageWithErrorReport.indexOf(validateEsqlResult.query))
|
||||
const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult)
|
||||
messageWithErrorReport = `${messageWithErrorReport.slice(0, index+3)}\n${errorMessage}\n${messageWithErrorReport.slice(index+3)}`
|
||||
messageWithErrorReport = `${messageWithErrorReport.slice(0, index + 3)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`
|
||||
})
|
||||
|
||||
return new HumanMessage({
|
||||
|
@ -95,7 +100,7 @@ const validateEsql = async (esClient: ElasticsearchClient, query: string): Promi
|
|||
try {
|
||||
await esClient.esql.query(
|
||||
{
|
||||
query: query,
|
||||
query: `${query}\n| LIMIT 0`, // Add a LIMIT 0 to minimize the risk of executing a costly query
|
||||
format: 'json',
|
||||
},
|
||||
)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue