Mirror of https://github.com/elastic/kibana.git (synced 2025-04-23 17:28:26 -04:00)

This commit is contained in:
  parent b4d52e440e
  commit 1ee648d672

22 changed files with 452 additions and 74 deletions
@@ -10,6 +10,16 @@ import { useCallback, useRef, useState } from 'react';
 import { ApiConfig, Replacements } from '@kbn/elastic-assistant-common';
 import { useAssistantContext } from '../../assistant_context';
 import { fetchConnectorExecuteAction, FetchConnectorExecuteResponse } from '../api';
+import * as i18n from './translations';
+
+/**
+ * TODO: This is a workaround for an issue with long-running server tasks while chatting with the assistant.
+ * Some models (like Llama 3.1 70B) can perform poorly and be slow, which leads to long request-handling times.
+ * The `core-http-browser` client has a timeout of two minutes, after which it retries the request. In combination with a slow model,
+ * this can lead to a situation where the core http client initiates the same request again and again.
+ * To avoid this, we abort the http request after a timeout that is slightly below two minutes.
+ */
+const EXECUTE_ACTION_TIMEOUT = 110 * 1000; // in milliseconds

 interface SendMessageProps {
   apiConfig: ApiConfig;
@@ -38,6 +48,11 @@ export const useSendMessage = (): UseSendMessage => {
     async ({ apiConfig, http, message, conversationId, replacements }: SendMessageProps) => {
       setIsLoading(true);
+
+      const timeoutId = setTimeout(() => {
+        abortController.current.abort(i18n.FETCH_MESSAGE_TIMEOUT_ERROR);
+        abortController.current = new AbortController();
+      }, EXECUTE_ACTION_TIMEOUT);

       try {
         return await fetchConnectorExecuteAction({
           conversationId,
@@ -52,6 +67,7 @@ export const useSendMessage = (): UseSendMessage => {
           traceOptions,
         });
       } finally {
+        clearTimeout(timeoutId);
         setIsLoading(false);
       }
     },
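Note on the change above: the hook arms a timer that aborts the in-flight request just before `core-http-browser`'s two-minute timeout would trigger a client-side retry. A minimal standalone sketch of the pattern follows; `fetchWithAbortTimeout` and its constants are illustrative names, not part of the Kibana API, and it assumes a runtime with global `fetch` and `AbortController`.

// Minimal sketch of the abort-before-client-retry pattern (illustrative names,
// not the Kibana API). The request is aborted just before the HTTP client's
// own two-minute timeout would fire, so the client never retries on its own.
const ABORT_TIMEOUT_MS = 110 * 1000; // abort slightly before the 2-minute mark

async function fetchWithAbortTimeout(url: string): Promise<Response> {
  const controller = new AbortController();
  const timeoutId = setTimeout(
    () => controller.abort(new Error('Request timed out before client retry')),
    ABORT_TIMEOUT_MS
  );
  try {
    // The signal propagates the abort to the underlying request.
    return await fetch(url, { signal: controller.signal });
  } finally {
    clearTimeout(timeoutId);
  }
}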
@@ -0,0 +1,15 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { i18n } from '@kbn/i18n';
+
+export const FETCH_MESSAGE_TIMEOUT_ERROR = i18n.translate(
+  'xpack.elasticAssistant.assistant.useSendMessage.fetchMessageTimeoutError',
+  {
+    defaultMessage: 'Assistant could not respond in time. Please try again later.',
+  }
+);
@@ -45,6 +45,7 @@ export interface AgentExecutorParams<T extends boolean> {
   esClient: ElasticsearchClient;
   langChainMessages: BaseMessage[];
   llmType?: string;
+  isOssModel?: boolean;
   logger: Logger;
   inference: InferenceServerStart;
   onNewReplacements?: (newReplacements: Replacements) => void;
@@ -94,6 +94,10 @@ export const getDefaultAssistantGraph = ({
     value: (x: boolean, y?: boolean) => y ?? x,
     default: () => false,
   },
+  isOssModel: {
+    value: (x: boolean, y?: boolean) => y ?? x,
+    default: () => false,
+  },
   conversation: {
     value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
       y ?? x,
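Each channel added to the graph state above pairs a reducer with a default; the reducer `(x, y) => y ?? x` keeps the current value unless a node writes a new one. A minimal illustration of that behavior, with generic names rather than the Kibana code:

// Minimal illustration of the "last write wins, otherwise keep" channel
// reducer used above (generic names, not the Kibana code).
type Reducer<T> = (current: T, update?: T) => T;

const lastWriteWins: Reducer<boolean> = (current, update) => update ?? current;

let isOssModel = false; // default: () => false
isOssModel = lastWriteWins(isOssModel, undefined); // node wrote nothing -> stays false
isOssModel = lastWriteWins(isOssModel, true); // node wrote true -> becomes true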
@@ -24,6 +24,7 @@ interface StreamGraphParams {
   assistantGraph: DefaultAssistantGraph;
   inputs: GraphInputs;
   logger: Logger;
+  isOssModel?: boolean;
   onLlmResponse?: OnLlmResponse;
   request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
   traceOptions?: TraceOptions;
@@ -36,6 +37,7 @@ interface StreamGraphParams {
  * @param assistantGraph
  * @param inputs
  * @param logger
+ * @param isOssModel
  * @param onLlmResponse
  * @param request
  * @param traceOptions
@@ -45,6 +47,7 @@ export const streamGraph = async ({
   assistantGraph,
   inputs,
   logger,
+  isOssModel,
   onLlmResponse,
   request,
   traceOptions,
@@ -80,8 +83,8 @@ export const streamGraph = async ({
   };

   if (
-    (inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') &&
-    inputs?.bedrockChatEnabled
+    inputs.isOssModel ||
+    ((inputs?.llmType === 'bedrock' || inputs?.llmType === 'gemini') && inputs?.bedrockChatEnabled)
   ) {
     const stream = await assistantGraph.streamEvents(
       inputs,
@@ -92,7 +95,9 @@ export const streamGraph = async ({
         version: 'v2',
         streamMode: 'values',
       },
-      inputs?.llmType === 'bedrock' ? { includeNames: ['Summarizer'] } : undefined
+      inputs.isOssModel || inputs?.llmType === 'bedrock'
+        ? { includeNames: ['Summarizer'] }
+        : undefined
     );

     for await (const { event, data, tags } of stream) {
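The two rewritten conditions above make OSS models take the same streaming path as Bedrock/Gemini chat models, and restrict surfaced events to the final `Summarizer` node for OSS and Bedrock. Restated as standalone predicates (hypothetical helpers, not part of the PR):

// Hypothetical helpers restating the gating logic above (not in the PR).
interface StreamInputs {
  llmType?: string;
  isOssModel?: boolean;
  bedrockChatEnabled?: boolean;
}

// streamEvents-based streaming is used for OSS models, or for Bedrock/Gemini
// when the bedrock chat flag is enabled.
const usesEventStreaming = (inputs: StreamInputs): boolean =>
  Boolean(inputs.isOssModel) ||
  ((inputs.llmType === 'bedrock' || inputs.llmType === 'gemini') &&
    Boolean(inputs.bedrockChatEnabled));

// Only the final "Summarizer" node's events are surfaced for OSS and Bedrock
// models; other models stream events from every node.
const summarizerOnly = (inputs: StreamInputs): boolean =>
  Boolean(inputs.isOssModel) || inputs.llmType === 'bedrock';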
@@ -36,6 +36,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
   inference,
   langChainMessages,
   llmType,
+  isOssModel,
   logger: parentLogger,
   isStream = false,
   onLlmResponse,
@@ -48,7 +49,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
   responseLanguage = 'English',
 }) => {
   const logger = parentLogger.get('defaultAssistantGraph');
-  const isOpenAI = llmType === 'openai';
+  const isOpenAI = llmType === 'openai' && !isOssModel;
   const llmClass = getLlmClass(llmType, bedrockChatEnabled);

   /**
@@ -111,7 +112,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
   };

   const tools: StructuredTool[] = assistantTools.flatMap(
-    (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
+    (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance(), isOssModel }) ?? []
   );

   // If KB enabled, fetch for any KB IndexEntries and generate a tool for each
@@ -166,6 +167,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
     conversationId,
     llmType,
     isStream,
+    isOssModel,
     input: latestMessage[0]?.content as string,
   };

@@ -175,6 +177,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
     assistantGraph,
     inputs,
     logger,
+    isOssModel,
     onLlmResponse,
     request,
     traceOptions,
@@ -22,7 +22,9 @@ interface ModelInputParams extends NodeParamsBase {
 export function modelInput({ logger, state }: ModelInputParams): Partial<AgentState> {
   logger.debug(() => `${NodeType.MODEL_INPUT}: Node state:\n${JSON.stringify(state, null, 2)}`);

-  const hasRespondStep = state.isStream && state.bedrockChatEnabled && state.llmType === 'bedrock';
+  const hasRespondStep =
+    state.isStream &&
+    (state.isOssModel || (state.bedrockChatEnabled && state.llmType === 'bedrock'));

   return {
     hasRespondStep,
@@ -18,3 +18,59 @@ const KB_CATCH =
 export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
 export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from ESQLKnowledgeBaseTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
 export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
+
+export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
+
+{tools}
+
+The tool action_input should ALWAYS follow the tool JSON schema args.
+
+Valid "action" values: "Final Answer" or {tool_names}
+
+Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
+
+Provide only ONE action per $JSON_BLOB, as shown:
+
+\`\`\`
+
+{{
+
+"action": $TOOL_NAME,
+
+"action_input": $TOOL_INPUT
+
+}}
+
+\`\`\`
+
+Follow this format:
+
+Question: input question to answer
+
+Thought: consider previous and subsequent steps
+
+Action:
+
+\`\`\`
+
+$JSON_BLOB
+
+\`\`\`
+
+Observation: action result
+
+... (repeat Thought/Action/Observation N times)
+
+Thought: I know what to respond
+
+Action:
+
+\`\`\`
+
+{{
+
+"action": "Final Answer",
+
+"action_input": "Final response to human"}}
+
+Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
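For context, an agent consuming this prompt replies with a single fenced JSON blob per turn. A generic sketch of how such output might be parsed (illustrative code, not LangChain's actual structured-chat parser):

// Generic sketch of parsing structured-chat agent output (illustrative only):
// extract the fenced JSON blob and dispatch on its "action" field.
interface AgentAction {
  action: string; // "Final Answer" or a tool name
  action_input: unknown; // must follow the tool's JSON schema args
}

function parseStructuredChatOutput(text: string): AgentAction {
  // The prompt mandates the format: Action:```$JSON_BLOB```
  const match = text.match(/```(?:json)?\s*([\s\S]*?)```/);
  if (!match) {
    throw new Error('No JSON blob found in model output');
  }
  return JSON.parse(match[1]) as AgentAction;
}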
@@ -11,6 +11,7 @@ import {
   DEFAULT_SYSTEM_PROMPT,
   GEMINI_SYSTEM_PROMPT,
   GEMINI_USER_PROMPT,
+  STRUCTURED_SYSTEM_PROMPT,
 } from './nodes/translations';

 export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
@@ -26,61 +27,7 @@ export const systemPrompts = {
   bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
   // The default prompt overwhelms gemini, do not prepend
   gemini: GEMINI_SYSTEM_PROMPT,
-  structuredChat: `Respond to the human as helpfully and accurately as possible. You have access to the following tools:
-
-{tools}
-
-The tool action_input should ALWAYS follow the tool JSON schema args.
-
-Valid "action" values: "Final Answer" or {tool_names}
-
-Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
-
-Provide only ONE action per $JSON_BLOB, as shown:
-
-\`\`\`
-
-{{
-
-"action": $TOOL_NAME,
-
-"action_input": $TOOL_INPUT
-
-}}
-
-\`\`\`
-
-Follow this format:
-
-Question: input question to answer
-
-Thought: consider previous and subsequent steps
-
-Action:
-
-\`\`\`
-
-$JSON_BLOB
-
-\`\`\`
-
-Observation: action result
-
-... (repeat Thought/Action/Observation N times)
-
-Thought: I know what to respond
-
-Action:
-
-\`\`\`
-
-{{
-
-"action": "Final Answer",
-
-"action_input": "Final response to human"}}
-
-Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`,
+  structuredChat: STRUCTURED_SYSTEM_PROMPT,
 };

 export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
@@ -20,6 +20,7 @@ export interface GraphInputs {
   conversationId?: string;
   llmType?: string;
   isStream?: boolean;
+  isOssModel?: boolean;
   input: string;
   responseLanguage?: string;
 }
@@ -31,6 +32,7 @@ export interface AgentState extends AgentStateBase {
   lastNode: string;
   hasRespondStep: boolean;
   isStream: boolean;
+  isOssModel: boolean;
   bedrockChatEnabled: boolean;
   llmType: string;
   responseLanguage: string;
@@ -30,6 +30,7 @@ import {
 } from '../helpers';
 import { transformESSearchToAnonymizationFields } from '../../ai_assistant_data_clients/anonymization_fields/helpers';
 import { EsAnonymizationFieldsSchema } from '../../ai_assistant_data_clients/anonymization_fields/types';
+import { isOpenSourceModel } from '../utils';

 export const SYSTEM_PROMPT_CONTEXT_NON_I18N = (context: string) => {
   return `CONTEXT:\n"""\n${context}\n"""`;
@@ -99,7 +100,9 @@ export const chatCompleteRoute = (
       const actions = ctx.elasticAssistant.actions;
       const actionsClient = await actions.getActionsClientWithRequest(request);
       const connectors = await actionsClient.getBulk({ ids: [connectorId] });
-      actionTypeId = connectors.length > 0 ? connectors[0].actionTypeId : '.gen-ai';
+      const connector = connectors.length > 0 ? connectors[0] : undefined;
+      actionTypeId = connector?.actionTypeId ?? '.gen-ai';
+      const isOssModel = isOpenSourceModel(connector);

       // replacements
       const anonymizationFieldsRes =
@@ -192,6 +195,7 @@ export const chatCompleteRoute = (
         actionsClient,
         actionTypeId,
         connectorId,
+        isOssModel,
         conversationId: conversationId ?? newConversation?.id,
         context: ctx,
         getElser,
@@ -46,7 +46,7 @@ import {
   openAIFunctionAgentPrompt,
   structuredChatAgentPrompt,
 } from '../../lib/langchain/graphs/default_assistant_graph/prompts';
-import { getLlmClass, getLlmType } from '../utils';
+import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils';

 const DEFAULT_SIZE = 20;
 const ROUTE_HANDLER_TIMEOUT = 10 * 60 * 1000; // 10 * 60 seconds = 10 minutes
@@ -174,10 +174,12 @@ export const postEvaluateRoute = (
         name: string;
         graph: DefaultAssistantGraph;
         llmType: string | undefined;
+        isOssModel: boolean | undefined;
       }> = await Promise.all(
         connectors.map(async (connector) => {
           const llmType = getLlmType(connector.actionTypeId);
-          const isOpenAI = llmType === 'openai';
+          const isOssModel = isOpenSourceModel(connector);
+          const isOpenAI = llmType === 'openai' && !isOssModel;
           const llmClass = getLlmClass(llmType, true);
           const createLlmInstance = () =>
             new llmClass({
@@ -232,6 +234,7 @@ export const postEvaluateRoute = (
             isEnabledKnowledgeBase,
             kbDataClient: dataClients?.kbDataClient,
             llm,
+            isOssModel,
             logger,
             modelExists: isEnabledKnowledgeBase,
             request: skeletonRequest,
@@ -274,6 +277,7 @@ export const postEvaluateRoute = (
           return {
             name: `${runName} - ${connector.name}`,
             llmType,
+            isOssModel,
             graph: getDefaultAssistantGraph({
               agentRunnable,
               dataClients,
@@ -287,7 +291,7 @@ export const postEvaluateRoute = (
       );

       // Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
-      await asyncForEach(graphs, async ({ name, graph, llmType }) => {
+      await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => {
         // Wrapper function for invoking the graph (to parse different input/output formats)
         const predict = async (input: { input: string }) => {
           logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
@@ -300,6 +304,7 @@ export const postEvaluateRoute = (
               llmType,
               bedrockChatEnabled: true,
               isStreaming: false,
+              isOssModel,
             }, // TODO: Update to use the correct input format per dataset type
             {
               runName,
@@ -310,15 +315,20 @@ export const postEvaluateRoute = (
           return output;
         };

-        const evalOutput = await evaluate(predict, {
+        evaluate(predict, {
           data: datasetName ?? '',
           evaluators: [], // Evals to be managed in LangSmith for now
           experimentPrefix: name,
           client: new Client({ apiKey: langSmithApiKey }),
+          // prevent rate limiting and unexpected multiple experiment runs
+          maxConcurrency: 5,
-        });
-        logger.debug(`runResp:\n ${JSON.stringify(evalOutput, null, 2)}`);
+        })
+          .then((output) => {
+            logger.debug(`runResp:\n ${JSON.stringify(output, null, 2)}`);
+          })
+          .catch((err) => {
+            logger.error(`evaluation error:\n ${JSON.stringify(err, null, 2)}`);
+          });
       });

       return response.ok({
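The evaluation above is now intentionally fire-and-forget: the route no longer awaits `evaluate`, so the HTTP response can return while runs proceed in the background, and the `.catch` prevents an unhandled promise rejection. A generic sketch of the pattern (illustrative names, not the Kibana code):

// Generic fire-and-forget sketch (illustrative, not the Kibana code): start a
// long-running job, return immediately, and log the outcome when it settles.
// Without .catch(), a rejection here would surface as an unhandled rejection.
function startBackgroundJob(
  job: () => Promise<unknown>,
  log: { debug: (msg: string) => void; error: (msg: string) => void }
): void {
  job()
    .then((result) => log.debug(`job finished: ${JSON.stringify(result)}`))
    .catch((err) => log.error(`job failed: ${String(err)}`));
}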
@@ -322,6 +322,7 @@ export interface LangChainExecuteParams {
   actionTypeId: string;
   connectorId: string;
   inference: InferenceServerStart;
+  isOssModel?: boolean;
   conversationId?: string;
   context: AwaitedProperties<
     Pick<ElasticAssistantRequestHandlerContext, 'elasticAssistant' | 'licensing' | 'core'>
@@ -348,6 +349,7 @@ export const langChainExecute = async ({
   telemetry,
   actionTypeId,
   connectorId,
+  isOssModel,
   context,
   actionsClient,
   inference,
@@ -412,6 +414,7 @@ export const langChainExecute = async ({
     inference,
     isStream,
     llmType: getLlmType(actionTypeId),
+    isOssModel,
     langChainMessages,
     logger,
     onNewReplacements,
@@ -29,6 +29,7 @@ import {
   getSystemPromptFromUserConversation,
   langChainExecute,
 } from './helpers';
+import { isOpenSourceModel } from './utils';

 export const postActionsConnectorExecuteRoute = (
   router: IRouter<ElasticAssistantRequestHandlerContext>,
@@ -94,6 +95,9 @@ export const postActionsConnectorExecuteRoute = (
       const actions = ctx.elasticAssistant.actions;
       const inference = ctx.elasticAssistant.inference;
       const actionsClient = await actions.getActionsClientWithRequest(request);
+      const connectors = await actionsClient.getBulk({ ids: [connectorId] });
+      const connector = connectors.length > 0 ? connectors[0] : undefined;
+      const isOssModel = isOpenSourceModel(connector);

       const conversationsDataClient =
         await assistantContext.getAIAssistantConversationsDataClient();
@@ -129,6 +133,7 @@ export const postActionsConnectorExecuteRoute = (
       actionsClient,
       actionTypeId,
       connectorId,
+      isOssModel,
       conversationId,
       context: ctx,
       getElser,
x-pack/plugins/elastic_assistant/server/routes/utils.test.ts (new file, 69 lines)

@@ -0,0 +1,69 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { Connector } from '@kbn/actions-plugin/server/application/connector/types';
+import { isOpenSourceModel } from './utils';
+import {
+  OPENAI_CHAT_URL,
+  OpenAiProviderType,
+} from '@kbn/stack-connectors-plugin/common/openai/constants';
+
+describe('Utils', () => {
+  describe('isOpenSourceModel', () => {
+    it('should return `false` when connector is undefined', async () => {
+      const isOpenModel = isOpenSourceModel();
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `false` when connector is a Bedrock connector', async () => {
+      const connector = { actionTypeId: '.bedrock' } as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `false` when connector is a Gemini connector', async () => {
+      const connector = { actionTypeId: '.gemini' } as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `false` when connector is an OpenAI connector and the API url is not specified', async () => {
+      const connector = {
+        actionTypeId: '.gen-ai',
+      } as unknown as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `false` when connector is an OpenAI connector and the OpenAI API url is specified', async () => {
+      const connector = {
+        actionTypeId: '.gen-ai',
+        config: { apiUrl: OPENAI_CHAT_URL },
+      } as unknown as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `false` when connector is an AzureOpenAI connector', async () => {
+      const connector = {
+        actionTypeId: '.gen-ai',
+        config: { apiProvider: OpenAiProviderType.AzureAi },
+      } as unknown as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(false);
+    });
+
+    it('should return `true` when connector is an OpenAI connector and a non-OpenAI API url is specified', async () => {
+      const connector = {
+        actionTypeId: '.gen-ai',
+        config: { apiUrl: 'https://elastic.llm.com/llama/chat/completions' },
+      } as unknown as Connector;
+      const isOpenModel = isOpenSourceModel(connector);
+      expect(isOpenModel).toEqual(true);
+    });
+  });
+});
@@ -19,6 +19,11 @@ import {
   ActionsClientSimpleChatModel,
   ActionsClientChatVertexAI,
 } from '@kbn/langchain/server';
+import { Connector } from '@kbn/actions-plugin/server/application/connector/types';
+import {
+  OPENAI_CHAT_URL,
+  OpenAiProviderType,
+} from '@kbn/stack-connectors-plugin/common/openai/constants';
 import { CustomHttpRequestError } from './custom_http_request_error';

 export interface OutputError {
@@ -189,3 +194,26 @@ export const getLlmClass = (llmType?: string, bedrockChatEnabled?: boolean) =>
     : llmType === 'gemini' && bedrockChatEnabled
     ? ActionsClientChatVertexAI
     : ActionsClientSimpleChatModel;
+
+export const isOpenSourceModel = (connector?: Connector): boolean => {
+  if (connector == null) {
+    return false;
+  }
+
+  const llmType = getLlmType(connector.actionTypeId);
+  const connectorApiUrl = connector.config?.apiUrl
+    ? (connector.config.apiUrl as string)
+    : undefined;
+  const connectorApiProvider = connector.config?.apiProvider
+    ? (connector.config?.apiProvider as OpenAiProviderType)
+    : undefined;
+
+  const isOpenAiType = llmType === 'openai';
+  const isOpenAI =
+    isOpenAiType &&
+    (!connectorApiUrl ||
+      connectorApiUrl === OPENAI_CHAT_URL ||
+      connectorApiProvider === OpenAiProviderType.AzureAi);
+
+  return isOpenAiType && !isOpenAI;
+};
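In plain terms, `isOpenSourceModel` treats a connector as open source when it is OpenAI-typed but points at a non-OpenAI, non-Azure endpoint (e.g. a self-hosted server exposing an OpenAI-compatible API). A hypothetical illustration, mirroring the tests above (the URL values are made up):

// Hypothetical illustration of the rule implemented above: a '.gen-ai'
// connector pointed at a custom endpoint (e.g. a self-hosted Llama server
// exposing an OpenAI-compatible API) is treated as an open-source model.
const llamaConnector = {
  actionTypeId: '.gen-ai',
  config: { apiUrl: 'http://localhost:11434/v1/chat/completions' },
};
// isOpenSourceModel(llamaConnector as Connector) === true
//
// The same connector type pointed at api.openai.com, or configured with the
// AzureAi provider, is NOT considered open source:
// isOpenSourceModel({ actionTypeId: '.gen-ai' } as Connector) === false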
@@ -244,6 +244,7 @@ export interface AssistantToolParams {
   kbDataClient?: AIAssistantKnowledgeBaseDataClient;
   langChainTimeout?: number;
   llm?: ActionsClientLlm | AssistantToolLlm;
+  isOssModel?: boolean;
   logger: Logger;
   modelExists: boolean;
   onNewReplacements?: (newReplacements: Replacements) => void;
@@ -0,0 +1,15 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+export const getPromptSuffixForOssModel = (toolName: string) => `
+When using ${toolName} tool ALWAYS pass the user's questions directly as input into the tool.
+
+Always return value from ${toolName} tool as is.
+
+The ES|QL query should ALWAYS be wrapped in triple backticks ("\`\`\`esql"). Add a new line character right before the triple backticks.
+
+It is important that ES|QL query is preceded by a new line.`;
@@ -12,6 +12,7 @@ import type { KibanaRequest } from '@kbn/core-http-server';
 import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
 import { loggerMock } from '@kbn/logging-mocks';
 import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base';
+import { getPromptSuffixForOssModel } from './common';

 describe('EsqlLanguageKnowledgeBaseTool', () => {
   const kbDataClient = jest.fn() as unknown as AIAssistantKnowledgeBaseDataClient;
@@ -108,5 +109,27 @@ describe('EsqlLanguageKnowledgeBaseTool', () => {

     expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']);
   });
+
+  it('should return tool with the expected description for OSS model', () => {
+    const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({
+      isEnabledKnowledgeBase: true,
+      modelExists: true,
+      isOssModel: true,
+      ...rest,
+    }) as DynamicTool;
+
+    expect(tool.description).toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool'));
+  });
+
+  it('should return tool with the expected description for non-OSS model', () => {
+    const tool = ESQL_KNOWLEDGE_BASE_TOOL.getTool({
+      isEnabledKnowledgeBase: true,
+      modelExists: true,
+      isOssModel: false,
+      ...rest,
+    }) as DynamicTool;
+
+    expect(tool.description).not.toContain(getPromptSuffixForOssModel('ESQLKnowledgeBaseTool'));
+  });
   });
 });
@@ -14,12 +14,15 @@ import { z } from '@kbn/zod';
 import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
 import { ESQL_RESOURCE } from '@kbn/elastic-assistant-plugin/server/routes/knowledge_base/constants';
 import { APP_UI_ID } from '../../../../common';
+import { getPromptSuffixForOssModel } from './common';
+
+const TOOL_NAME = 'ESQLKnowledgeBaseTool';

 const toolDetails = {
+  id: 'esql-knowledge-base-tool',
+  name: TOOL_NAME,
   description:
     'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language. Input must always be the user query on a single line, with no other text. Your answer will be parsed as JSON, so never use quotes within the output and instead use backticks. Do not add any additional text to describe your output.',
-  id: 'esql-knowledge-base-tool',
-  name: 'ESQLKnowledgeBaseTool',
 };
 export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = {
   ...toolDetails,
@@ -31,12 +34,13 @@ export const ESQL_KNOWLEDGE_BASE_TOOL: AssistantTool = {
   getTool(params: AssistantToolParams) {
     if (!this.isSupported(params)) return null;

-    const { kbDataClient } = params as AssistantToolParams;
+    const { kbDataClient, isOssModel } = params as AssistantToolParams;
     if (kbDataClient == null) return null;

     return new DynamicStructuredTool({
       name: toolDetails.name,
-      description: toolDetails.description,
+      description:
+        toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
       schema: z.object({
         question: z.string().describe(`The user's exact question about ESQL`),
       }),
@@ -0,0 +1,162 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type { RetrievalQAChain } from 'langchain/chains';
+import type { DynamicTool } from '@langchain/core/tools';
+import { NL_TO_ESQL_TOOL } from './nl_to_esql_tool';
+import type { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
+import type { KibanaRequest } from '@kbn/core-http-server';
+import type { ExecuteConnectorRequestBody } from '@kbn/elastic-assistant-common/impl/schemas/actions_connector/post_actions_connector_execute_route.gen';
+import { loggerMock } from '@kbn/logging-mocks';
+import { getPromptSuffixForOssModel } from './common';
+import type { InferenceServerStart } from '@kbn/inference-plugin/server';
+
+describe('NaturalLanguageESQLTool', () => {
+  const chain = {} as RetrievalQAChain;
+  const esClient = {
+    search: jest.fn().mockResolvedValue({}),
+  } as unknown as ElasticsearchClient;
+  const request = {
+    body: {
+      isEnabledKnowledgeBase: false,
+      alertsIndexPattern: '.alerts-security.alerts-default',
+      allow: ['@timestamp', 'cloud.availability_zone', 'user.name'],
+      allowReplacement: ['user.name'],
+      replacements: { key: 'value' },
+      size: 20,
+    },
+  } as unknown as KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
+  const logger = loggerMock.create();
+  const inference = {} as InferenceServerStart;
+  const connectorId = 'fake-connector';
+  const rest = {
+    chain,
+    esClient,
+    logger,
+    request,
+    inference,
+    connectorId,
+  };
+
+  describe('isSupported', () => {
+    it('returns false if isEnabledKnowledgeBase is false', () => {
+      const params = {
+        isEnabledKnowledgeBase: false,
+        modelExists: true,
+        ...rest,
+      };
+
+      expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false);
+    });
+
+    it('returns false if modelExists is false (the ELSER model is not installed)', () => {
+      const params = {
+        isEnabledKnowledgeBase: true,
+        modelExists: false, // <-- ELSER model is not installed
+        ...rest,
+      };
+
+      expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(false);
+    });
+
+    it('returns true if isEnabledKnowledgeBase and modelExists are true', () => {
+      const params = {
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        ...rest,
+      };
+
+      expect(NL_TO_ESQL_TOOL.isSupported(params)).toBe(true);
+    });
+  });
+
+  describe('getTool', () => {
+    it('returns null if isEnabledKnowledgeBase is false', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: false,
+        modelExists: true,
+        ...rest,
+      });
+
+      expect(tool).toBeNull();
+    });
+
+    it('returns null if modelExists is false (the ELSER model is not installed)', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: false, // <-- ELSER model is not installed
+        ...rest,
+      });
+
+      expect(tool).toBeNull();
+    });
+
+    it('returns null if inference plugin is not provided', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        ...rest,
+        inference: undefined,
+      });
+
+      expect(tool).toBeNull();
+    });
+
+    it('returns null if connectorId is not provided', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        ...rest,
+        connectorId: undefined,
+      });
+
+      expect(tool).toBeNull();
+    });
+
+    it('should return a Tool instance if isEnabledKnowledgeBase and modelExists are true', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        ...rest,
+      });
+
+      expect(tool?.name).toEqual('NaturalLanguageESQLTool');
+    });
+
+    it('should return a tool with the expected tags', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        ...rest,
+      }) as DynamicTool;
+
+      expect(tool.tags).toEqual(['esql', 'query-generation', 'knowledge-base']);
+    });
+
+    it('should return tool with the expected description for OSS model', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        isOssModel: true,
+        ...rest,
+      }) as DynamicTool;
+
+      expect(tool.description).toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
+    });
+
+    it('should return tool with the expected description for non-OSS model', () => {
+      const tool = NL_TO_ESQL_TOOL.getTool({
+        isEnabledKnowledgeBase: true,
+        modelExists: true,
+        isOssModel: false,
+        ...rest,
+      }) as DynamicTool;
+
+      expect(tool.description).not.toContain(getPromptSuffixForOssModel('NaturalLanguageESQLTool'));
+    });
+  });
+});
@@ -11,6 +11,7 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-
 import { lastValueFrom } from 'rxjs';
 import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
 import { APP_UI_ID } from '../../../../common';
+import { getPromptSuffixForOssModel } from './common';

 export type ESQLToolParams = AssistantToolParams;

@@ -37,7 +38,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
   getTool(params: ESQLToolParams) {
     if (!this.isSupported(params)) return null;

-    const { connectorId, inference, logger, request } = params as ESQLToolParams;
+    const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
     if (inference == null || connectorId == null) return null;

     const callNaturalLanguageToEsql = async (question: string) => {
@@ -46,6 +47,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
         client: inference.getClient({ request }),
         connectorId,
         input: question,
+        ...(isOssModel ? { functionCalling: 'simulated' } : {}),
         logger: {
           debug: (source) => {
             logger.debug(typeof source === 'function' ? source() : source);
@@ -57,7 +59,8 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {

     return new DynamicStructuredTool({
       name: toolDetails.name,
-      description: toolDetails.description,
+      description:
+        toolDetails.description + (isOssModel ? getPromptSuffixForOssModel(TOOL_NAME) : ''),
       schema: z.object({
         question: z.string().describe(`The user's exact question about ESQL`),
       }),