[Security Assistant] AI Assistant - Better Solution for OSS models (#10416) (#194166)

This commit is contained in:
Ievgen Sorokopud 2024-10-07 23:41:20 +02:00 committed by GitHub
parent b4d52e440e
commit 1ee648d672
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 452 additions and 74 deletions

View file

@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react';
import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common';
import { useAssistantContext } from '../../assistant_context';
import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api';
import * as i18n from './translations';
/**
* TODO: This is a workaround for the issue of long-running server tasks while chatting with the assistant.
* Some models (like Llama 3.1 70B) can perform poorly and be slow, which leads to a long time to handle the request.
* The `core-http-browser` client has a timeout of two minutes, after which it retries the request. In combination with a slow model this can lead to
* a situation where the core http client initiates the same request again and again.
* To avoid this, we abort the http request after a timeout that is slightly below two minutes.
*/
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds
interface SendMessageProps {
apiConfig: ApiConfig;
@ -38,6 +48,11 @@ export const useSendMessage = (): UseSendMessage => {
async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => {
setIsLoading(true);
const timeoutId = setTimeout(() => {
abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR);
abortController.current = new AbortController();
}, EXECUTE_ACTION_TIMEOUT);
try {
return await fetchConnectorExecuteAction({
conversationId,
@ -52,6 +67,7 @@ export const useSendMessage = (): UseSendMessage => {
traceOptions,
});
} finally {
clearTimeout(timeoutId);
setIsLoading(false);
}
},
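
For reference, a minimal standalone sketch of the same abort-before-retry pattern, assuming a plain fetch-based call rather than the Kibana http client (the endpoint and function name are illustrative):

// Sketch only: abort a long-running request slightly before the two-minute client retry window.
const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // milliseconds

async function sendWithTimeout(url: string, body: unknown): Promise<Response> {
  const abortController = new AbortController();
  // If the model is still working after the timeout, abort so the client never auto-retries.
  const timeoutId = setTimeout(
    () => abortController.abort('Assistant could not respond in time.'),
    EXECUTE_ACTION_TIMEOUT
  );
  try {
    return await fetch(url, {
      method: 'POST',
      body: JSON.stringify(body),
      signal: abortController.signal, // the pending fetch rejects with an AbortError once aborted
    });
  } finally {
    clearTimeout(timeoutId); // always clear, whether the request resolved, failed, or was aborted
  }
}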

View file

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { i18n } from '@kbn/i18n';
export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate(
'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError',
{
defaultMessage: 'Assistant could not respond in time. Please try again later.',
}
);

View file

@ -45,6 +45,7 @@ export interface AgentExecutorParams<T extends boolean> {
esClient: ElasticsearchClient;
langChainMessages: BaseMessage[];
llmType?: string;
isOssModel?: boolean;
logger: Logger;
inference: InferenceServerStart;
onNewReplacements?: (newReplacements: Replacements) => void;

View file

@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
isOssModel: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,

View file

@ -24,6 +24,7 @@ interface StreamGraphParams {
assistantGraph: DefaultAssistantGraph;
inputs: GraphInputs;
logger: Logger;
isOssModel?: boolean;
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
traceOptions?: TraceOptions;
@ -36,6 +37,7 @@ interface StreamGraphParams {
* @param assistantGraph
* @param inputs
* @param logger
* @param isOssModel
* @param onLlmResponse
* @param request
* @param traceOptions
@ -45,6 +47,7 @@ export const streamGraph = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,
@ -80,8 +83,8 @@ export const streamGraph = async ({
};
if (
(inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
inputs?.bedrockChatEnabled
inputs.isOssModel ||
((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled)
) {
const stream = await assistantGraph.streamEvents(
inputs,
@ -92,7 +95,9 @@ export const streamGraph = async ({
version: 'v2',
streamMode: 'values',
},
inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
inputs.isOssModel || inputs?.llmType === 'bedrock'
? { includeNames: ['Summarizer'] }
: undefined
);
for await (const { event, data, tags } of stream) {

View file

@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
inference,
langChainMessages,
llmType,
isOssModel,
logger: parentLogger,
isStream = false,
onLlmResponse,
@ -48,7 +49,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
responseLanguage = 'English',
}) => {
const logger = parentLogger.get('defaultAssistantGraph');
const isOpenAI = llmType === 'openai';
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, bedrockChatEnabled);
/**
@ -111,7 +112,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
};
const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
);
// If KB enabled, fetch for any KB IndexEntries and generate a tool for each
@ -166,6 +167,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
conversationId,
llmType,
isStream,
isOssModel,
input: latestMessage[0]?.content as string,
};
@ -175,6 +177,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
assistantGraph,
inputs,
logger,
isOssModel,
onLlmResponse,
request,
traceOptions,

View file

@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase {
export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);
const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
const hasRespondStep =
state.isStream &&
(state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock'));
return {
hasRespondStep,

View file

@ -18,3 +18,59 @@ const KB_CATCH =
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
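
For reference, one concrete turn in this format could look like the following (the tool name, input, and query are illustrative; single braces are used here because the doubled braces in the prompt above are template escaping):

Question: How do I count documents per host in ES|QL?
Thought: I should consult the ES|QL knowledge base tool.
Action:
```
{
  "action": "ESQLKnowledgeBaseTool",
  "action_input": { "question": "How do I count documents per host in ES|QL?" }
}
```
Observation: FROM logs-* | STATS count = COUNT(*) BY host.name
Thought: I know what to respond
Action:
```
{
  "action": "Final Answer",
  "action_input": "Use \`\`\`esql\nFROM logs-* | STATS count = COUNT(*) BY host.name\n\`\`\`"
}
```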

View file

@ -11,6 +11,7 @@ import {
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './nodes/translations';
export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
@ -26,61 +27,7 @@ export const systemPrompts = {
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
// The default prompt overwhelms gemini, do not prepend
gemini: GEMINI_SYSTEM_PROMPT,
structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`,
structuredChat: STRUCTURED_SYSTEM_PROMPT,
};
export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);

View file

@ -20,6 +20,7 @@ export interface GraphInputs {
conversationId?: string;
llmType?: string;
isStream?: boolean;
isOssModel?: boolean;
input: string;
responseLanguage?: string;
}
@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase {
lastNode: string;
hasRespondStep: boolean;
isStream: boolean;
isOssModel: boolean;
bedrockChatEnabled: boolean;
llmType: string;
responseLanguage: string;

View file

@ -30,6 +30,7 @@ import {
} from '../helpers';
import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers';
import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types';
import { isOpenSourceModel } from '../utils';
export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => {
return `CONTEXT:\n"""\n${context}\n"""`;
@ -99,7 +100,9 @@ export const chatCompleteRoute = (
const actions = ctx.elasticAssistant.actions;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai';
const connector = connectors.length > 0 ? connectors[0] : undefined;
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
const isOssModel = isOpenSourceModel(connector);
// replacements
const anonymizationFieldsRes =
@ -192,6 +195,7 @@ export const chatCompleteRoute = (
actionsClient,
actionTypeId,
connectorId,
isOssModel,
conversationId: conversationId ?? newConversation?.id,
context: ctx,
getElser,

View file

@ -46,7 +46,7 @@ import {
openAIFunctionAgentPrompt,
structuredChatAgentPrompt,
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
import { getLlmClass, getLlmType } from '../utils';
import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils';
const DEFAULT_SIZE = 20;
const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes
@ -174,10 +174,12 @@ export const postEvaluateRoute = (
name: string;
graph: DefaultAssistantGraph;
llmType: string | undefined;
isOssModel: boolean | undefined;
}> = await Promise.all(
connectors.map(async (connector) => {
const llmType = getLlmType(connector.actionTypeId);
const isOpenAI = llmType === 'openai';
const isOssModel = isOpenSourceModel(connector);
const isOpenAI = llmType === 'openai' && !isOssModel;
const llmClass = getLlmClass(llmType, true);
const createLlmInstance = () =>
new llmClass({
@ -232,6 +234,7 @@ export const postEvaluateRoute = (
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
llm,
isOssModel,
logger,
modelExists: isEnabledKnowledgeBase,
request: skeletonRequest,
@ -274,6 +277,7 @@ export const postEvaluateRoute = (
return {
name: `${runName} - ${connector.name}`,
llmType,
isOssModel,
graph: getDefaultAssistantGraph({
agentRunnable,
dataClients,
@ -287,7 +291,7 @@ export const postEvaluateRoute = (
);
// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
await asyncForEach(graphs, async ({ name, graph, llmType }) => {
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => {
// Wrapper function for invoking the graph (to parse different input/output formats)
const predict = async (input: { input: string }) => {
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
@ -300,6 +304,7 @@ export const postEvaluateRoute = (
llmType,
bedrockChatEnabled: true,
isStreaming: false,
isOssModel,
}, // TODO: Update to use the correct input format per dataset type
{
runName,
@ -310,15 +315,20 @@ export const postEvaluateRoute = (
return output;
};
const evalOutput = await evaluate(predict, {
evaluate(predict, {
data: datasetName ?? '',
evaluators: [], // Evals to be managed in LangSmith for now
experimentPrefix: name,
client: new Client({ apiKey: langSmithApiKey }),
// prevent rate limiting and unexpected multiple experiment runs
maxConcurrency: 5,
});
logger.debug(`runResp:\n ${JSON.stringify(evalOutput, null, 2)}`);
})
.then((output) => {
logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`);
})
.catch((err) => {
logger.error(`evaluation error:\n ${JSON.stringify(err, null, 2)}`);
});
});
return response.ok({

View file

@ -322,6 +322,7 @@ export interface LangChainExecuteParams {
actionTypeId: string;
connectorId: string;
inference: InferenceServerStart;
isOssModel?: boolean;
conversationId?: string;
context: AwaitedProperties<
Pick<ElasticAssistantRequestHandlerContext, 'elasticAssistant' | 'licensing' | 'core'>
@ -348,6 +349,7 @@ export const langChainExecute = async ({
telemetry,
actionTypeId,
connectorId,
isOssModel,
context,
actionsClient,
inference,
@ -412,6 +414,7 @@ export const langChainExecute = async ({
inference,
isStream,
llmType: getLlmType(actionTypeId),
isOssModel,
langChainMessages,
logger,
onNewReplacements,

View file

@ -29,6 +29,7 @@ import {
getSystemPromptFromUserConversation,
langChainExecute,
} from './helpers';
import { isOpenSourceModel } from './utils';
export const postActionsConnectorExecuteRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>,
@ -94,6 +95,9 @@ export const postActionsConnectorExecuteRoute = (
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const actionsClient = await actions.getActionsClientWithRequest(request);
const connectors = await actionsClient.getBulk({ ids: [connectorId] });
const connector = connectors.length > 0 ? connectors[0] : undefined;
const isOssModel = isOpenSourceModel(connector);
const conversationsDataClient =
await assistantContext.getAIAssistantConversationsDataClient();
@ -129,6 +133,7 @@ export const postActionsConnectorExecuteRoute = (
actionsClient,
actionTypeId,
connectorId,
isOssModel,
conversationId,
context: ctx,
getElser,

View file

@ -0,0 +1,69 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { isOpenSourceModel } from './utils';
import {
OPENAI_CHAT_URL,
OpenAiProviderType,
} from '@kbn/stack-connectors-plugin/common/openai/constants';
describe('Utils', () => {
describe('isOpenSourceModel', () => {
it('should return `false` when connector is undefined', async () => {
const isOpenModel = isOpenSourceModel();
expect(isOpenModel).toEqual(false);
});
it('should return `false` when connector is Bedrock', async () => {
const connector = { actionTypeId: '.bedrock' } as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(false);
});
it('should return `false` when connector is Gemini', async () => {
const connector = { actionTypeId: '.gemini' } as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(false);
});
it('should return `false` when connector is OpenAI and the API url is not specified', async () => {
const connector = {
actionTypeId: '.gen-ai',
} as unknown as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(false);
});
it('should return `false` when connector is OpenAI and the OpenAI API url is specified', async () => {
const connector = {
actionTypeId: '.gen-ai',
config: { apiUrl: OPENAI_CHAT_URL },
} as unknown as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(false);
});
it('should return `false` when connector is AzureOpenAI', async () => {
const connector = {
actionTypeId: '.gen-ai',
config: { apiProvider: OpenAiProviderType.AzureAi },
} as unknown as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(false);
});
it('should return `true` when connector is OpenAI and a non-OpenAI API url is specified', async () => {
const connector = {
actionTypeId: '.gen-ai',
config: { apiUrl: 'https://elastic.llm.com/llama/chat/completions' },
} as unknown as Connector;
const isOpenModel = isOpenSourceModel(connector);
expect(isOpenModel).toEqual(true);
});
});
});

View file

@ -19,6 +19,11 @@ import {
ActionsClientSimpleChatModel,
ActionsClientChatVertexAI,
} from '@kbn/langchain/server';
import { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import {
OPENAI_CHAT_URL,
OpenAiProviderType,
} from '@kbn/stack-connectors-plugin/common/openai/constants';
import { CustomHttpRequestError } from './custom_http_request_error';
export interface OutputError {
@ -189,3 +194,26 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) =>
: llmType === 'gemini' && bedrockChatEnabled
? ActionsClientChatVertexAI
: ActionsClientSimpleChatModel;
export const isOpenSourceModel = (connector?: Connector): boolean => {
if (connector == null) {
return false;
}
const llmType = getLlmType(connector.actionTypeId);
const connectorApiUrl = connector.config?.apiUrl
? (connector.config.apiUrl as string)
: undefined;
const connectorApiProvider = connector.config?.apiProvider
? (connector.config?.apiProvider as OpenAiProviderType)
: undefined;
const isOpenAiType = llmType === 'openai';
const isOpenAI =
isOpenAiType &&
(!connectorApiUrl ||
connectorApiUrl === OPENAI_CHAT_URL ||
connectorApiProvider === OpenAiProviderType.AzureAi);
return isOpenAiType && !isOpenAI;
};
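
A small usage sketch (the connector objects are illustrative; any `.gen-ai` connector pointing at an endpoint other than the OpenAI or Azure OpenAI APIs is classified as an open source model):

import { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { OpenAiProviderType } from '@kbn/stack-connectors-plugin/common/openai/constants';
import { isOpenSourceModel } from './utils';

// Hypothetical connector configs, for illustration only.
const selfHostedLlama = {
  actionTypeId: '.gen-ai',
  config: { apiUrl: 'https://my-llm-gateway.internal/v1/chat/completions' },
} as unknown as Connector;

const azureOpenAi = {
  actionTypeId: '.gen-ai',
  config: { apiProvider: OpenAiProviderType.AzureAi },
} as unknown as Connector;

isOpenSourceModel(selfHostedLlama); // true: OpenAI-compatible connector, non-OpenAI endpoint
isOpenSourceModel(azureOpenAi);     // false: Azure OpenAI provider
isOpenSourceModel();                // false: no connector at all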

View file

@ -244,6 +244,7 @@ export interface AssistantToolParams {
kbDataClient?: AIAssistantKnowledgeBaseDataClient;
langChainTimeout?: number;
llm?: ActionsClientLlm | AssistantToolLlm;
isOssModel?: boolean;
logger: Logger;
modelExists: boolean;
onNewReplacements?: (newReplacements: Replacements) => void;

View file

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getPromptSuffixForOssModel = (toolName: string) => `
When using the ${toolName} tool ALWAYS pass the user's question directly as input into the tool.
Always return the value from the ${toolName} tool as is.
The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks.
It is important that the ES|QL query is preceded by a new line.`;
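
For reference, the response shape this suffix asks the model for looks roughly like the following (the query itself is illustrative):

Here is the ES|QL query:

```esql
FROM logs-*
| WHERE user.name == "alice"
| LIMIT 10
```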

View file

@ -12,6 +12,7 @@ import type { KibanaRequest } from '@kbn/core-http-server';
import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
import { loggerMock } from '@kbn/logging-mocks';
import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base';
import { getPromptSuffixForOssModel } from './common';
describe('EsqlLanguageKnowledgeBaseTool', () => {
const kbDataClient = jest.fn() as unknown as AIAssistantKnowledgeBaseDataClient;
@ -108,5 +109,27 @@ describe('EsqlLanguageKnowledgeBaseTool', () => {
expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']);
});
it('should return tool with the expected description for OSS model', () => {
const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
isOssModel: true,
...rest,
}) as DynamicTool;
expect(tool.description).toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool'));
});
it('should return tool with the expected description for non-OSS model', () => {
const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
isOssModel: false,
...rest,
}) as DynamicTool;
expect(tool.description).not.toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool'));
});
});
});

View file

@ -14,12 +14,15 @@ import { z } from '@kbn/zod';
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
import { ESQL_RESOURCE } from '@kbn/elastic-assistant-plugin/server/routes/knowledge_base/constants';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './common';
const TOOL_NAME = 'ESQLKnowledgeBaseTool';
const toolDetails = {
id: 'esql-knowledge-base-tool',
name: TOOL_NAME,
description:
'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the user query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. Do not add any additional text to describe your output.',
id: 'esql-knowledge-base-tool',
name: 'ESQLKnowledgeBaseTool',
};
export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = {
...toolDetails,
@ -31,12 +34,13 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = {
getTool(params: AssistantToolParams) {
if (!this.isSupported(params)) return null;
const { kbDataClient } = params as AssistantToolParams;
const { kbDataClient, isOssModel } = params as AssistantToolParams;
if (kbDataClient == null) return null;
return new DynamicStructuredTool({
name: toolDetails.name,
description: toolDetails.description,
description:
toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),

View file

@ -0,0 +1,162 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { RetrievalQAChain } from 'langchain/chains';
import type { DynamicTool } from '@langchain/core/tools';
import { NL_TO_ESQL_TOOL } from './nl_to_esql_tool';
import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
import { loggerMock } from '@kbn/logging-mocks';
import { getPromptSuffixForOssModel } from './common';
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
describe('NaturalLanguageESQLTool', () => {
const chain = {} as RetrievalQAChain;
const esClient = {
search: jest.fn().mockResolvedValue({}),
} as unknown as ElasticsearchClient;
const request = {
body: {
isEnabledKnowledgeBase: false,
alertsIndexPattern: '.alerts-security.alerts-default',
allow: ['@timestamp', 'cloud.availability_zone', 'user.name'],
allowReplacement: ['user.name'],
replacements: { key: 'value' },
size: 20,
},
} as unknown as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
const logger = loggerMock.create();
const inference = {} as InferenceServerStart;
const connectorId = 'fake-connector';
const rest = {
chain,
esClient,
logger,
request,
inference,
connectorId,
};
describe('isSupported', () => {
it('returns false if isEnabledKnowledgeBase is false', () => {
const params = {
isEnabledKnowledgeBase: false,
modelExists: true,
...rest,
};
expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false);
});
it('returns false if modelExists is false (the ELSER model is not installed)', () => {
const params = {
isEnabledKnowledgeBase: true,
modelExists: false, // <-- ELSER model is not installed
...rest,
};
expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false);
});
it('returns true if isEnabledKnowledgeBase and modelExists are true', () => {
const params = {
isEnabledKnowledgeBase: true,
modelExists: true,
...rest,
};
expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(true);
});
});
describe('getTool', () => {
it('returns null if isEnabledKnowledgeBase is false', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: false,
modelExists: true,
...rest,
});
expect(tool).toBeNull();
});
it('returns null if modelExists is false (the ELSER model is not installed)', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: false, // <-- ELSER model is not installed
...rest,
});
expect(tool).toBeNull();
});
it('returns null if inference plugin is not provided', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
...rest,
inference: undefined,
});
expect(tool).toBeNull();
});
it('returns null if connectorId is not provided', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
...rest,
connectorId: undefined,
});
expect(tool).toBeNull();
});
it('should return a Tool instance if isEnabledKnowledgeBase and modelExists are true', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
...rest,
});
expect(tool?.name).toEqual('NaturalLanguageESQLTool');
});
it('should return a tool with the expected tags', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
...rest,
}) as DynamicTool;
expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']);
});
it('should return tool with the expected description for OSS model', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
isOssModel: true,
...rest,
}) as DynamicTool;
expect(tool.description).toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
});
it('should return tool with the expected description for non-OSS model', () => {
const tool = NL_TO_ESQL_TOOL.getTool({
isEnabledKnowledgeBase: true,
modelExists: true,
isOssModel: false,
...rest,
}) as DynamicTool;
expect(tool.description).not.toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
});
});
});

View file

@ -11,6 +11,7 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-
import { lastValueFrom } from 'rxjs';
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
import { APP_UI_ID } from '../../../../common';
import { getPromptSuffixForOssModel } from './common';
export type ESQLToolParams = AssistantToolParams;
@ -37,7 +38,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
getTool(params: ESQLToolParams) {
if (!this.isSupported(params)) return null;
const { connectorId, inference, logger, request } = params as ESQLToolParams;
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
if (inference == null || connectorId == null) return null;
const callNaturalLanguageToEsql = async (question: string) => {
@ -46,6 +47,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
client: inference.getClient({ request }),
connectorId,
input: question,
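// For OSS models, fall back to prompt-based ('simulated') function calling
// (assumption: many open source models lack reliable native function/tool calling).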
...(isOssModel ? { functionCalling: 'simulated' } : {}),
logger: {
debug: (source) => {
logger.debug(typeof source === 'function' ? source() : source);
@ -57,7 +59,8 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
return new DynamicStructuredTool({
name: toolDetails.name,
description: toolDetails.description,
description:
toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
schema: z.object({
question: z.string().describe(`The user's exact question about ESQL`),
}),