mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 17:28:26 -04:00
attempt 1
This commit is contained in:
parent
b79e33db32
commit
ad3dce1179
4 changed files with 201 additions and 59 deletions
|
@ -29,7 +29,7 @@ export const localToolPrompts: Prompt[] = [
|
|||
- convert queries from another language to ES|QL
|
||||
- asks general questions about ES|QL
|
||||
|
||||
ALWAYS use this tool to generate ES|QL queries or explain anything about the ES|QL query language rather than coming up with your own answer.`,
|
||||
ALWAYS use this tool to generate ES|QL queries or explain anything about the ES|QL query language rather than coming up with your own answer. The tool will validate the query.`,
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
// Test suite for the NL-to-ES|QL validator.
// The suite currently has no real assertions; the placeholder is registered as
// a `todo` so the test report shows the missing coverage instead of a
// vacuously passing empty test.
describe('natural_language_to_esql_validator', () => {
  it.todo('validates generated ES|QL queries and retries when parsing or execution fails');
});
|
|
@ -0,0 +1,182 @@
|
|||
import { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
|
||||
import { EditorError, parse } from '@kbn/esql-ast';
|
||||
import { InferenceServerStart, naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
import { lastValueFrom } from 'rxjs';
|
||||
import { getEsqlFromContent } from './common';
|
||||
import { isEmpty, pick } from 'lodash';
|
||||
import { z, ZodSchema } from 'zod';
|
||||
import { JsonSchema7Type, zodToJsonSchema } from 'zod-to-json-schema';
|
||||
import { OpenAPIV3 } from 'openapi-types';
|
||||
import { ToolSchema } from '@kbn/inference-common';
|
||||
import { keyword } from '@kbn/fleet-plugin/server/services/epm/elasticsearch/template/mappings';
|
||||
|
||||
/**
 * A single ES|QL query handled by the validator.
 */
interface NlToEsqlCommand {
  /** The ES|QL query text. */
  query: string;
}
|
||||
|
||||
/**
 * An ES|QL query together with the errors collected while validating it.
 * Both error fields are optional; a command counts as failed when at least one
 * of them is set.
 */
interface NlToEsqlCommandWithError extends NlToEsqlCommand {
  /** Errors reported by the ES|QL parser for this query. */
  parsingErrors?: EditorError[];
  /** The value thrown while executing the query, kept as-is. */
  executeError?: unknown;
}
|
||||
|
||||
/**
 * Generates ES|QL queries from a natural-language question and validates each
 * generated query before returning it.
 *
 * Every query extracted from the model response is parsed with `@kbn/esql-ast`
 * and then executed through the Elasticsearch client. When either step fails,
 * the errors are formatted into a follow-up prompt and generation is retried
 * recursively.
 */
export class NaturalLanguageToEsqlValidator {
  /** Recursion depth after which no further generation attempt is made. */
  private static readonly MAX_RETRY_DEPTH = 3;

  /**
   * Whether the list of available indices is appended to the retry prompt on an
   * "Unknown index" error. The original code had this branch hard-disabled
   * (`if (false && ...)`); it stays off here.
   */
  private static readonly APPEND_AVAILABLE_INDICES = false;

  private readonly inference: InferenceServerStart;
  private readonly connectorId: string;
  private readonly logger: Logger;
  private readonly request: KibanaRequest;
  private readonly esClient: ElasticsearchClient;

  /**
   * @param params.inference Inference plugin start contract used to obtain a client.
   * @param params.connectorId Connector passed to the NL-to-ES|QL task.
   * @param params.logger Logger used by the task and by this class.
   * @param params.request Request the inference client is created for.
   * @param params.esClient Client used to execute queries and list indices.
   */
  constructor(params: {
    inference: InferenceServerStart;
    connectorId: string;
    logger: Logger;
    request: KibanaRequest;
    esClient: ElasticsearchClient;
  }) {
    this.inference = params.inference;
    this.connectorId = params.connectorId;
    this.logger = params.logger;
    this.request = params.request;
    this.esClient = params.esClient;
  }

  /**
   * Generates ES|QL for `question`, validating and retrying as needed.
   *
   * @param question Natural-language question (or ES|QL-related request).
   * @returns The queries that passed validation. When no usable query could be
   * produced, the array instead holds the raw model content, an
   * "Unable to generate query." message, or the last retry prompt.
   */
  public async generateEsqlFromNaturalLanguage(question: string) {
    return this.recursivlyGenerateAndValidateEsql(question);
  }

  /** Runs the NL-to-ES|QL task for `question` and resolves with its last emitted event. */
  private async callNaturalLanguageToEsql(question: string) {
    return lastValueFrom(
      naturalLanguageToEsql({
        client: this.inference.getClient({ request: this.request }),
        connectorId: this.connectorId,
        input: question,
        functionCalling: 'auto',
        logger: this.logger,
        // NOTE(review): this tool is only declared here; no code in this class
        // handles a call to it. Confirm whether that is intended.
        tools: {
          get_available_indecies: {
            description:
              'Get the available indecies in the elastic search cluster. Use this when there is an unknown index error.',
            schema: zodSchemaToInference(
              z.object({
                keyword: z.string(),
              })
            ),
          },
        },
      })
    );
  }

  /**
   * Parses `esqlQuery` and reports parser errors.
   *
   * @returns The query with its parsing errors, or `undefined` when the parser
   * reported none.
   */
  private esqlParsingErrors(esqlQuery: string): NlToEsqlCommandWithError | undefined {
    const { errors: parsingErrors } = parse(esqlQuery);

    if (isEmpty(parsingErrors)) {
      return undefined;
    }
    return { query: esqlQuery, parsingErrors };
  }

  /**
   * Executes `esqlQuery` through the Elasticsearch client to find runtime errors.
   *
   * @returns The query with the thrown value, or `undefined` when execution succeeded.
   */
  private async testRunQuery(esqlQuery: string): Promise<NlToEsqlCommandWithError | undefined> {
    try {
      await this.esClient.esql.query({
        query: esqlQuery,
        format: 'json',
      });
    } catch (e) {
      return { query: esqlQuery, executeError: e };
    }
    return undefined;
  }

  /** Builds a prompt fragment listing the index names returned by `cat.indices`. */
  private async getAvailableIndeciesPrompt(): Promise<string> {
    const records = await this.esClient.cat.indices({ format: 'json' });
    const indexNames = records.map((record) => record.index);
    return `The available indecies are\n${indexNames.join('\n')}`;
  }

  /**
   * Generates queries for `question`, validates each one, and recurses with an
   * error prompt for every query that fails.
   *
   * @param question Prompt sent to the model (the user question or a retry prompt).
   * @param depth Current recursion depth; generation stops once it exceeds
   * `MAX_RETRY_DEPTH`, returning `question` unchanged.
   */
  private async recursivlyGenerateAndValidateEsql(question: string, depth = 0): Promise<string[]> {
    if (depth > NaturalLanguageToEsqlValidator.MAX_RETRY_DEPTH) {
      return [question];
    }

    const generateEvent = await this.callNaturalLanguageToEsql(question);
    this.logger.debug(`NL to ESQL generate event: ${JSON.stringify(generateEvent)}`);

    if (!generateEvent.content) {
      return [`Unable to generate query.\n${question}`];
    }

    const queries = getEsqlFromContent(generateEvent.content);
    if (isEmpty(queries)) {
      // No query in the response: hand the model's text back as the answer.
      return [generateEvent.content];
    }

    const results = await Promise.all(
      queries.map(async (query): Promise<string | string[] | undefined> => {
        if (isEmpty(query)) return undefined;

        // Parse first; only execute queries the parser accepts.
        const failure = this.esqlParsingErrors(query) ?? (await this.testRunQuery(query));
        if (!this.isNlToEsqlCommandWithError(failure)) {
          return query;
        }

        const retryPrompt = await this.formatEsqlQueryErrorForPrompt(failure);
        return this.recursivlyGenerateAndValidateEsql(retryPrompt, depth + 1);
      })
    );

    return results.flat().filter((result): result is string => result !== undefined);
  }

  /** Type guard: true when `command` carries parsing errors or an execution error. */
  private isNlToEsqlCommandWithError(
    command: undefined | NlToEsqlCommand | NlToEsqlCommandWithError
  ): command is NlToEsqlCommandWithError {
    const candidate: NlToEsqlCommandWithError | undefined = command;
    return candidate?.parsingErrors !== undefined || candidate?.executeError !== undefined;
  }

  /**
   * Formats a failed query and its errors into a retry prompt for the model.
   *
   * @throws Error when `error` carries neither parsing nor execution errors.
   */
  private async formatEsqlQueryErrorForPrompt(error: NlToEsqlCommand): Promise<string> {
    if (!this.isNlToEsqlCommandWithError(error)) {
      throw new Error('Error is not an NlToEsqlCommandWithError');
    }

    // The fence markers sit on their own lines so the query is not fused with
    // the `esql` language tag.
    const lines = [
      'The query below could not be executed due to the following errors. Try again or reply with debugging instructions',
      '```esql',
      error.query,
      '```',
    ];

    if (error.parsingErrors) {
      lines.push('Parsing Errors:', ...error.parsingErrors.map((parsingError) => parsingError.message));
    }

    if (error.executeError !== undefined) {
      const { executeError } = error;
      // The thrown value is `unknown`; only Error instances have a reliable message.
      const message = executeError instanceof Error ? executeError.message : String(executeError);
      lines.push('Execution Errors:', message);
    }

    let errorString = `${lines.join('\n')}\n`;

    if (
      NaturalLanguageToEsqlValidator.APPEND_AVAILABLE_INDICES &&
      errorString.includes('Unknown index')
    ) {
      errorString += await this.getAvailableIndeciesPrompt();
    }

    this.logger.error(errorString);

    return errorString;
  }
}
|
||||
|
||||
/**
 * Converts a zod schema into the tool schema shape used by the inference
 * plugin, keeping only the `type`, `properties` and `required` keys of the
 * generated JSON schema.
 */
function zodSchemaToInference(schema: ZodSchema): ToolSchema {
  const jsonSchema = zodToJsonSchema(schema);
  const toolSchema = pick(jsonSchema, ['type', 'properties', 'required']);
  return toolSchema as ToolSchema;
}
|
|
@ -14,9 +14,10 @@ import { APP_UI_ID } from '../../../../common';
|
|||
import { getEsqlFromContent, getPromptSuffixForOssModel } from './common';
|
||||
import { NlToEsqlTaskEvent } from '@kbn/inference-plugin/server/tasks/nl_to_esql';
|
||||
import { ToolOptions } from '@kbn/inference-common';
|
||||
import { parseEsqlQuery } from '@kbn/securitysolution-utils';
|
||||
import { parse } from '@kbn/esql-ast';
|
||||
import isEmpty from 'lodash/isEmpty';
|
||||
import { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { NaturalLanguageToEsqlValidator } from './natual_language_to_esql_validator';
|
||||
|
||||
// select only some properties of AssistantToolParams
|
||||
export type ESQLToolParams = AssistantToolParams;
|
||||
|
@ -34,7 +35,7 @@ const toolDetails = {
|
|||
- convert queries from another language to ES|QL
|
||||
- asks general questions about ES|QL
|
||||
|
||||
ALWAYS use this tool to generate ES|QL queries or explain anything about the ES|QL query language rather than coming up with your own answer.`,
|
||||
ALWAYS use this tool to generate ES|QL queries or explain anything about the ES|QL query language rather than coming up with your own answer. The tool will validate the query.`,
|
||||
};
|
||||
|
||||
export const NL_TO_ESQL_TOOL: AssistantTool = {
|
||||
|
@ -47,30 +48,23 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
getTool(params: ESQLToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
|
||||
const { connectorId, inference, logger, request, isOssModel, esClient } = params as ESQLToolParams;
|
||||
if (inference == null || connectorId == null) return null;
|
||||
|
||||
const callNaturalLanguageToEsql = async (question: string) => {
|
||||
return lastValueFrom(
|
||||
naturalLanguageToEsql({
|
||||
client: inference.getClient({ request }),
|
||||
connectorId,
|
||||
input: question,
|
||||
functionCalling: 'auto',
|
||||
logger,
|
||||
})
|
||||
);
|
||||
};
|
||||
const naturalLanguageToEsqlValidator = new NaturalLanguageToEsqlValidator({
|
||||
inference,
|
||||
connectorId,
|
||||
logger,
|
||||
request,
|
||||
esClient,
|
||||
})
|
||||
|
||||
return tool(
|
||||
async (input) => {
|
||||
|
||||
const answer = esqlValidator(callNaturalLanguageToEsql)({question: input.question});
|
||||
|
||||
/* const generateEvent = await callNaturalLanguageToEsql(input.question);
|
||||
|
||||
const answer = generateEvent.content ?? 'An error occurred in the tool'; */
|
||||
const answer = await naturalLanguageToEsqlValidator.generateEsqlFromNaturalLanguage(input.question);
|
||||
|
||||
console.log(`Received response from NL to ESQL tool: ${answer}`)
|
||||
logger.debug(`Received response from NL to ESQL tool: ${answer}`);
|
||||
return answer;
|
||||
},
|
||||
|
@ -88,43 +82,4 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
},
|
||||
};
|
||||
|
||||
/** Function that asks the NL-to-ES|QL task a question and resolves with its final event. */
type NatualLanguageToEsqlFunction = (
  question: string
) => Promise<NlToEsqlTaskEvent<ToolOptions<string>>>;

/** Recursion depth at which generation gives up and returns a failure message. */
const maxDepth = 3;

/**
 * Wraps an NL-to-ES|QL function so every generated query is syntax-checked.
 *
 * Each query extracted from the response is parsed. A query with parser errors
 * triggers another generation attempt whose prompt contains the query and the
 * error messages, up to `maxDepth` levels deep.
 *
 * @param func Function that performs the actual NL-to-ES|QL generation.
 * @returns An async function taking `{ question }` and resolving with the
 * queries that parsed cleanly (or a failure message for exhausted retries).
 */
const esqlValidator = (func: NatualLanguageToEsqlFunction) => {
  return async (input: { question: string }) => {
    const helper = async (question: string, depth = 0): Promise<(string | undefined)[]> => {
      if (depth >= maxDepth) {
        return [`Unable to generate a valid query for the given question: "${question}"`];
      }

      const generateEvent = await func(question);
      const queries = getEsqlFromContent(generateEvent.content);

      const results = await Promise.all(
        queries.map(async (query) => {
          // Queries are validated exactly as generated; nothing is appended to them.
          if (isEmpty(query)) return undefined;

          const { errors } = parse(query);
          if (isEmpty(errors)) {
            return query;
          }

          const errorString = errors.map((e) => e.message).join('\n');
          const retryString = `The following query has some syntax errors\n\n"${query}"\n\n${errorString}\n\nPlease try again.`;
          return helper(retryString, depth + 1);
        })
      );

      // Drop the placeholders left by empty queries.
      return results.flat().filter((result) => result !== undefined);
    };

    return helper(input.question);
  };
};
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue