prevent too many ESQl generation attempts

This commit is contained in:
Kenneth Kreindler 2025-03-07 13:50:36 +00:00
parent c76a4469ae
commit 58e6e1eded
No known key found for this signature in database
GPG key ID: 429CB8689E46A00B
4 changed files with 31 additions and 11 deletions

View file

@ -37,7 +37,8 @@ export const getNlToEsqlAgent = ({
return new Command({
update:{
messages : [nlToEsqlTaskEventToLangchainMessage(result)]
messages : [nlToEsqlTaskEventToLangchainMessage(result)],
maximumLLMCalls: state.maximumLLMCalls - 1
}
})
}

View file

@ -56,7 +56,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
return tool(
async ({ question }) => {
const humanMessage = new HumanMessage({ content: question })
const result = await selfHealingGraph.invoke({ messages: [humanMessage] })
const result = await selfHealingGraph.invoke({ messages: [humanMessage]})
const { messages } = result
const lastMessage = messages[messages.length - 1]
return lastMessage.content

View file

@ -1,3 +1,17 @@
import { MessagesAnnotation } from "@langchain/langgraph";
import { BaseMessage } from "@langchain/core/messages";
import { Annotation, messagesStateReducer } from "@langchain/langgraph";
export const EsqlSelfHealingAnnotation = MessagesAnnotation
export const EsqlSelfHealingAnnotation = Annotation.Root({
messages: Annotation<BaseMessage[]>({
reducer: messagesStateReducer,
default: () => [],
}),
maximumValidationAttempts: Annotation<number>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => 10,
}),
maximumLLMCalls: Annotation<number>({
reducer: (currentValue, newValue) => newValue ?? currentValue,
default: () => 30,
}),
});

View file

@ -27,7 +27,10 @@ export const getValidatorNode = ({
if (!generatedQueries.length) {
return new Command({
goto: END
goto: END,
update: {
maximumValidationAttempts: state.maximumValidationAttempts - 1,
}
})
}
@ -40,16 +43,18 @@ export const getValidatorNode = ({
if (containsInvalidQueries) {
return new Command({
update: {
messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)]
messages: [lastMessageWithErrorReport(lastMessage.content as string, validateEsqlResults)],
maximumValidationAttempts: state.maximumValidationAttempts - 1
},
goto: NL_TO_ESQL_AGENT_NODE
goto: state.maximumValidationAttempts <= 0 || state.maximumLLMCalls <= 0 ? END : NL_TO_ESQL_AGENT_NODE
})
}
return new Command({
goto: END,
update:{
messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to modified futher.`
update: {
messages: `${lastMessage.content} \nAll of the queries have been validated and do not need to modified futher.`,
maximumValidationAttempts: state.maximumValidationAttempts - 1
}
})
}
@ -60,7 +65,7 @@ const lastMessageWithErrorReport = (message: string, validateEsqlResults: Valida
validateEsqlResults.reverse().forEach((validateEsqlResult) => {
const index = messageWithErrorReport.indexOf('```', messageWithErrorReport.indexOf(validateEsqlResult.query))
const errorMessage = formatValidateEsqlResultToHumanReadable(validateEsqlResult)
messageWithErrorReport = `${messageWithErrorReport.slice(0, index+3)}\n${errorMessage}\n${messageWithErrorReport.slice(index+3)}`
messageWithErrorReport = `${messageWithErrorReport.slice(0, index + 3)}\n${errorMessage}\n${messageWithErrorReport.slice(index + 3)}`
})
return new HumanMessage({
@ -95,7 +100,7 @@ const validateEsql = async (esClient: ElasticsearchClient, query: string): Promi
try {
await esClient.esql.query(
{
query: query,
query: `${query}\n| LIMIT 0`, // Add a LIMIT 0 to minimize the risk of executing a costly query
format: 'json',
},
)