This commit is contained in:
Kenneth Kreindler 2025-03-04 20:25:24 +00:00
parent 02b9f8f249
commit b79e33db32
No known key found for this signature in database
GPG key ID: 429CB8689E46A00B
3 changed files with 81 additions and 3 deletions

View file

@ -0,0 +1,10 @@
import { getEsqlFromContent } from "./common"
describe("common", () => {
it.each([
["```esqlhelloworld```", ['helloworld']],
["```esqlhelloworld``````esqlhelloworld```", ['helloworld','helloworld']]
])('should add %s and %s', (input: string, expectedResult: string[]) => {
expect(getEsqlFromContent(input)).toEqual(expectedResult)
})
})

View file

@ -13,3 +13,21 @@ export const getPromptSuffixForOssModel = (toolName: string) => `
The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks.
It is important that ES|QL query is preceeded by a new line.`;
export const getEsqlFromContent = (content: string) : string[] => {
const extractedEsql = []
let index = 0
while (index < content.length) {
const start = content.indexOf('```esql', index)
if (start === -1) {
break
}
const end = content.indexOf('```', start + 7)
if (end === -1) {
break
}
extractedEsql.push(content.slice(start + 7, end))
index = end + 3
}
return extractedEsql
}

View file

@ -11,7 +11,12 @@ import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { z } from '@kbn/zod';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './common';
import { getEsqlFromContent, getPromptSuffixForOssModel } from './common';
import { NlToEsqlTaskEvent } from '@kbn/inference-plugin/server/tasks/nl_to_esql';
import { ToolOptions } from '@kbn/inference-common';
import { parseEsqlQuery } from '@kbn/securitysolution-utils';
import { parse } from '@kbn/esql-ast';
import isEmpty from 'lodash/isEmpty';
// select only some properties of AssistantToolParams
export type ESQLToolParams = AssistantToolParams;
@ -59,8 +64,12 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
return tool(
async (input) => {
const generateEvent = await callNaturalLanguageToEsql(input.question);
const answer = generateEvent.content ?? 'An error occurred in the tool';
const answer = esqlValidator(callNaturalLanguageToEsql)({question: input.question});
/* const generateEvent = await callNaturalLanguageToEsql(input.question);
const answer = generateEvent.content ?? 'An error occurred in the tool'; */
logger.debug(`Received response from NL to ESQL tool: ${answer}`);
return answer;
@ -78,3 +87,44 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
);
},
};
type NatualLanguageToEsqlFunction = (question: string) => Promise<NlToEsqlTaskEvent<ToolOptions<string>>>
const maxDepth = 3;
const esqlValidator = (func: NatualLanguageToEsqlFunction) => {
return async (input: { question: string }) => {
const helper = async (question: string, depth = 0): Promise<(string | undefined)[]> => {
console.log(question);
if (depth >= maxDepth) {
return [`Unable to generate a valid query for the given question: "${question}"`];
}
const generateEvent = await func(question);
const queries = getEsqlFromContent(generateEvent.content);
const results = await Promise.all(
queries.map(async (query) => {
query = query+"."
if (isEmpty(query)) return undefined;
const { errors } = parse(query);
if (!isEmpty(errors)) {
const errorString = errors.map((e) => e.message).join("\n");
const retryString = `The following query has some syntax errors\n\n"${query}"\n\n${errorString}\n\nPlease try again.`;
return helper(retryString, depth + 1);
}
return query;
})
);
return results.flat();
};
return helper(input.question);
};
};