mirror of
https://github.com/elastic/kibana.git
synced 2025-06-27 10:40:07 -04:00
[Security Solution] [AI assistant ] Fix error where llm.bindTools is not a function. (#225268)
## Summary Summarize your PR. If it involves visual changes, include a screenshot or gif. This PR fixes a bug where the error message "llm.bindTools is not a function" would appear in the Security AI assistant.  Changes: - Make AssistantTool.getTool return a promise. This means tools can be created asynchronously. This eliminates the error, as the error stems from the promise `createLlmInstance()` ([ref](https://github.com/elastic/kibana/pull/225268/files#diff-69e7fc6c29ce0673d7d33c0472a012ad310fa571487a6b594d2e1378b3e5f246R286)) not being awaited. - Improve type definition for tools so that we avoid bugs when the AssistantTool type changes e.g. https://github.com/elastic/kibana/pull/225268/files#diff-b603523fee68a791bd3af770b780fc654eb7866c8d2a73192d29fa935c80e541R17 ### How to test: - Enable AdvancedEsqlGeneration feature flag: ```yml # kibana.dev.yml xpack.securitySolution.enableExperimental: ['advancedEsqlGeneration'] ``` - Start Kibana - Open the Security AI assistant - Ask a question - Expect to see a response from the LLM. ### Checklist Check the PR satisfies following conditions. Reviewers should verify this PR satisfies this list as well. - [x] Any text added follows [EUI's writing guidelines](https://elastic.github.io/eui/#/guidelines/writing), uses sentence case text and includes [i18n support](https://github.com/elastic/kibana/blob/main/src/platform/packages/shared/kbn-i18n/README.md) - [x] [Documentation](https://www.elastic.co/guide/en/kibana/master/development-documentation.html) was added for features that require explanation or tutorials - [x] [Unit or functional tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html) were updated or added to match the most common scenarios - [x] If a plugin configuration key changed, check if it needs to be allowlisted in the cloud and added to the [docker list](https://github.com/elastic/kibana/blob/main/src/dev/build/tasks/os_packages/docker_generator/resources/base/bin/kibana-docker) - [x] This was checked for breaking HTTP API changes, and any breaking changes have been approved by the breaking-change committee. The `release_note:breaking` label should be applied in these situations. - [x] [Flaky Test Runner](https://ci-stats.kibana.dev/trigger_flaky_test_runner/1) was used on any tests changed - [x] The PR description includes the appropriate Release Notes section, and the correct `release_note:*` label is applied per the [guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process) ### Identify risks Does this PR introduce any risks? For example, consider risks like hard to test bugs, performance regression, potential of data loss. Describe the risk, its severity, and mitigation for each identified risk. Invite stakeholders and evaluate how to proceed before merging. - [ ] [See some risk examples](https://github.com/elastic/kibana/blob/main/RISK_MATRIX.mdx) - [ ] ... --------- Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
parent
4937b3a849
commit
dc24f2068b
29 changed files with 167 additions and 182 deletions
|
@ -248,7 +248,7 @@ export interface AssistantTool {
|
|||
description: string;
|
||||
sourceRegister: string;
|
||||
isSupported: (params: AssistantToolParams) => boolean;
|
||||
getTool: (params: AssistantToolParams) => StructuredToolInterface | null;
|
||||
getTool: (params: AssistantToolParams) => Promise<StructuredToolInterface | null>;
|
||||
}
|
||||
|
||||
export type AssistantToolLlm =
|
||||
|
@ -285,3 +285,14 @@ export interface AssistantToolParams {
|
|||
telemetry?: AnalyticsServiceSetup;
|
||||
createLlmInstance?: () => Promise<AssistantToolLlm>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper type for working with AssistantToolParams when some properties are required.
|
||||
*
|
||||
*
|
||||
* ```ts
|
||||
* export type MyNewTypeWithAssistantContext = Require<AssistantToolParams, 'assistantContext'>
|
||||
* ```
|
||||
*/
|
||||
|
||||
export type Require<T extends object, P extends keyof T> = Omit<T, P> & Required<Pick<T, P>>;
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
|
||||
import type {
|
||||
ActionsClientChatOpenAI,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientSimpleChatModel,
|
||||
} from '@kbn/langchain/server/language_models';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
|
@ -17,12 +16,12 @@ import fs from 'fs/promises';
|
|||
import path from 'path';
|
||||
import type { ElasticsearchClient, KibanaRequest } from '@kbn/core/server';
|
||||
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
||||
import type { ActionsClientChatBedrockConverse } from '@kbn/langchain/server';
|
||||
import { getGenerateEsqlGraph as getGenerateEsqlAgent } from '../../server/assistant/tools/esql/graphs/generate_esql/generate_esql';
|
||||
import { getRuleMigrationAgent } from '../../server/lib/siem_migrations/rules/task/agent';
|
||||
import type { RuleMigrationsRetriever } from '../../server/lib/siem_migrations/rules/task/retrievers';
|
||||
import type { EsqlKnowledgeBase } from '../../server/lib/siem_migrations/rules/task/util/esql_knowledge_base';
|
||||
import type { SiemMigrationTelemetryClient } from '../../server/lib/siem_migrations/rules/task/rule_migrations_telemetry_client';
|
||||
import type { CreateLlmInstance } from '../../server/assistant/tools/esql/utils/common';
|
||||
|
||||
interface Drawable {
|
||||
drawMermaidPng: () => Promise<Blob>;
|
||||
|
@ -53,17 +52,13 @@ async function getSiemMigrationGraph(logger: Logger): Promise<Drawable> {
|
|||
}
|
||||
|
||||
async function getGenerateEsqlGraph(logger: Logger): Promise<Drawable> {
|
||||
const graph = getGenerateEsqlAgent({
|
||||
const graph = await getGenerateEsqlAgent({
|
||||
esClient: {} as unknown as ElasticsearchClient,
|
||||
connectorId: 'test-connector-id',
|
||||
inference: {} as unknown as InferenceServerStart,
|
||||
logger,
|
||||
request: {} as unknown as KibanaRequest,
|
||||
createLlmInstance: () =>
|
||||
({ bindTools: () => null } as unknown as
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI),
|
||||
createLlmInstance: (() => ({ bindTools: () => null })) as unknown as CreateLlmInstance,
|
||||
});
|
||||
return graph.getGraphAsync({ xray: true });
|
||||
}
|
||||
|
|
|
@ -86,13 +86,13 @@ describe('AlertCountsTool', () => {
|
|||
|
||||
describe('getTool', () => {
|
||||
it('returns a `DynamicTool` with a `func` that calls `esClient.search()` with the expected query', async () => {
|
||||
const tool: DynamicTool = ALERT_COUNTS_TOOL.getTool({
|
||||
const tool: DynamicTool = (await ALERT_COUNTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
esClient,
|
||||
replacements,
|
||||
request,
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
await tool.func('');
|
||||
|
||||
|
@ -163,13 +163,13 @@ describe('AlertCountsTool', () => {
|
|||
});
|
||||
|
||||
it('includes citations', async () => {
|
||||
const tool: DynamicTool = ALERT_COUNTS_TOOL.getTool({
|
||||
const tool: DynamicTool = (await ALERT_COUNTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
esClient,
|
||||
replacements,
|
||||
request,
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
(contentReferencesStore.add as jest.Mock).mockImplementation(
|
||||
(creator: Parameters<ContentReferencesStore['add']>[0]) => {
|
||||
|
@ -184,8 +184,8 @@ describe('AlertCountsTool', () => {
|
|||
expect(result).toContain('Citation: {reference(exampleContentReferenceId)}');
|
||||
});
|
||||
|
||||
it('returns null when the alertsIndexPattern is undefined', () => {
|
||||
const tool = ALERT_COUNTS_TOOL.getTool({
|
||||
it('returns null when the alertsIndexPattern is undefined', async () => {
|
||||
const tool = await ALERT_COUNTS_TOOL.getTool({
|
||||
// alertsIndexPattern is undefined
|
||||
esClient,
|
||||
replacements,
|
||||
|
@ -197,15 +197,15 @@ describe('AlertCountsTool', () => {
|
|||
expect(tool).toBeNull();
|
||||
});
|
||||
|
||||
it('returns a tool instance with the expected tags', () => {
|
||||
const tool = ALERT_COUNTS_TOOL.getTool({
|
||||
it('returns a tool instance with the expected tags', async () => {
|
||||
const tool = (await ALERT_COUNTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
esClient,
|
||||
replacements,
|
||||
request,
|
||||
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
expect(tool.tags).toEqual(['alerts', 'alerts-count']);
|
||||
});
|
||||
|
|
|
@ -10,12 +10,12 @@ import { tool } from '@langchain/core/tools';
|
|||
import { requestHasRequiredAnonymizationParams } from '@kbn/elastic-assistant-plugin/server/lib/langchain/helpers';
|
||||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import { contentReferenceString, securityAlertsPageReference } from '@kbn/elastic-assistant-common';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { getAlertsCountQuery } from './get_alert_counts_query';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export interface AlertCountsToolParams extends AssistantToolParams {
|
||||
alertsIndexPattern: string;
|
||||
}
|
||||
export type AlertCountsToolParams = Require<AssistantToolParams, 'alertsIndexPattern'>;
|
||||
|
||||
export const ALERT_COUNTS_TOOL_DESCRIPTION =
|
||||
'Call this for the counts of last 24 hours of open and acknowledged alerts in the environment, grouped by their severity and workflow status. The response will be JSON and from it you can summarize the information to answer the question.';
|
||||
|
||||
|
@ -31,7 +31,7 @@ export const ALERT_COUNTS_TOOL: AssistantTool = {
|
|||
const { request, alertsIndexPattern } = params;
|
||||
return requestHasRequiredAnonymizationParams(request) && alertsIndexPattern != null;
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
const { alertsIndexPattern, esClient, contentReferencesStore } =
|
||||
params as AlertCountsToolParams;
|
||||
|
|
|
@ -10,13 +10,11 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-
|
|||
import { z } from '@kbn/zod';
|
||||
import { lastValueFrom } from 'rxjs';
|
||||
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
import { getPromptSuffixForOssModel } from './utils/common';
|
||||
|
||||
export type ESQLToolParams = AssistantToolParams & {
|
||||
assistantContext: ElasticAssistantApiRequestHandlerContext;
|
||||
};
|
||||
export type ESQLToolParams = Require<AssistantToolParams, 'assistantContext'>;
|
||||
|
||||
const TOOL_NAME = 'AskAboutEsqlTool';
|
||||
|
||||
|
@ -49,7 +47,7 @@ export const ASK_ABOUT_ESQL_TOOL: AssistantTool = {
|
|||
assistantContext.getRegisteredFeatures('securitySolutionUI').advancedEsqlGeneration
|
||||
);
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
|
||||
|
|
|
@ -9,23 +9,15 @@ import { tool } from '@langchain/core/tools';
|
|||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import { z } from '@kbn/zod';
|
||||
import { HumanMessage } from '@langchain/core/messages';
|
||||
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
import { getPromptSuffixForOssModel } from './utils/common';
|
||||
import { getGenerateEsqlGraph } from './graphs/generate_esql/generate_esql';
|
||||
|
||||
export type GenerateEsqlParams = AssistantToolParams & {
|
||||
assistantContext: ElasticAssistantApiRequestHandlerContext;
|
||||
createLlmInstance: () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
};
|
||||
export type GenerateEsqlParams = Require<
|
||||
AssistantToolParams,
|
||||
'assistantContext' | 'createLlmInstance'
|
||||
>;
|
||||
|
||||
const TOOL_NAME = 'GenerateESQLTool';
|
||||
|
||||
|
@ -55,7 +47,7 @@ export const GENERATE_ESQL_TOOL: AssistantTool = {
|
|||
createLlmInstance != null
|
||||
);
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, inference, logger, request, isOssModel, esClient, createLlmInstance } =
|
||||
|
@ -63,7 +55,7 @@ export const GENERATE_ESQL_TOOL: AssistantTool = {
|
|||
|
||||
if (inference == null || connectorId == null) return null;
|
||||
|
||||
const selfHealingGraph = getGenerateEsqlGraph({
|
||||
const selfHealingGraph = await getGenerateEsqlGraph({
|
||||
esClient,
|
||||
connectorId,
|
||||
inference,
|
||||
|
|
|
@ -24,24 +24,28 @@ import { getAnalyzeCompressedIndexMappingAgent } from './nodes/analyze_compresse
|
|||
import { getExplorePartialIndexMappingAgent } from './nodes/explore_partial_index_mapping_agent/explore_partial_index_mapping_agent';
|
||||
import { getExplorePartialIndexMappingResponder } from './nodes/explore_partial_index_mapping_responder/explore_partial_index_mapping_responder';
|
||||
|
||||
export const getAnalyzeIndexPatternGraph = ({
|
||||
export const getAnalyzeIndexPatternGraph = async ({
|
||||
esClient,
|
||||
createLlmInstance,
|
||||
}: {
|
||||
esClient: ElasticsearchClient;
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
}) => {
|
||||
const [
|
||||
analyzeCompressedIndexMappingAgent,
|
||||
explorePartialIndexMappingAgent,
|
||||
explorePartialIndexMappingResponder,
|
||||
] = await Promise.all([
|
||||
getAnalyzeCompressedIndexMappingAgent({ createLlmInstance }),
|
||||
getExplorePartialIndexMappingAgent({ esClient, createLlmInstance }),
|
||||
getExplorePartialIndexMappingResponder({ createLlmInstance }),
|
||||
]);
|
||||
|
||||
const graph = new StateGraph(AnalyzeIndexPatternAnnotation)
|
||||
.addNode(GET_FIELD_DESCRIPTORS, getFieldDescriptors({ esClient }))
|
||||
.addNode(
|
||||
ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT,
|
||||
getAnalyzeCompressedIndexMappingAgent({ createLlmInstance })
|
||||
)
|
||||
.addNode(ANALYZE_COMPRESSED_INDEX_MAPPING_AGENT, analyzeCompressedIndexMappingAgent)
|
||||
|
||||
.addNode(
|
||||
EXPLORE_PARTIAL_INDEX_AGENT,
|
||||
getExplorePartialIndexMappingAgent({ esClient, createLlmInstance })
|
||||
)
|
||||
.addNode(EXPLORE_PARTIAL_INDEX_AGENT, explorePartialIndexMappingAgent)
|
||||
.addNode(TOOLS, (state: typeof AnalyzeIndexPatternAnnotation.State) => {
|
||||
const { input } = state;
|
||||
if (input === undefined) {
|
||||
|
@ -55,10 +59,7 @@ export const getAnalyzeIndexPatternGraph = ({
|
|||
const toolNode = new ToolNode(tools);
|
||||
return toolNode.invoke(state);
|
||||
})
|
||||
.addNode(
|
||||
EXPLORE_PARTIAL_INDEX_RESPONDER,
|
||||
getExplorePartialIndexMappingResponder({ createLlmInstance })
|
||||
)
|
||||
.addNode(EXPLORE_PARTIAL_INDEX_RESPONDER, explorePartialIndexMappingResponder)
|
||||
|
||||
.addEdge(START, GET_FIELD_DESCRIPTORS)
|
||||
.addConditionalEdges(
|
||||
|
|
|
@ -19,12 +19,12 @@ const structuredOutput = z.object({
|
|||
.describe('Whether the index pattern contains the required fields for the query'),
|
||||
});
|
||||
|
||||
export const getAnalyzeCompressedIndexMappingAgent = ({
|
||||
export const getAnalyzeCompressedIndexMappingAgent = async ({
|
||||
createLlmInstance,
|
||||
}: {
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
}) => {
|
||||
const llm = createLlmInstance();
|
||||
const llm = await createLlmInstance();
|
||||
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
|
||||
const { fieldDescriptors, input } = state;
|
||||
if (fieldDescriptors === undefined) {
|
||||
|
|
|
@ -12,18 +12,19 @@ import type { CreateLlmInstance } from '../../../../utils/common';
|
|||
import type { AnalyzeIndexPatternAnnotation } from '../../state';
|
||||
import { getInspectIndexMappingTool } from '../../../../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
|
||||
|
||||
export const getExplorePartialIndexMappingAgent = ({
|
||||
export const getExplorePartialIndexMappingAgent = async ({
|
||||
createLlmInstance,
|
||||
esClient,
|
||||
}: {
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
esClient: ElasticsearchClient;
|
||||
}) => {
|
||||
const llm = createLlmInstance();
|
||||
const llm = await createLlmInstance();
|
||||
const tool = getInspectIndexMappingTool({
|
||||
esClient,
|
||||
indexPattern: 'placeholder',
|
||||
});
|
||||
|
||||
const llmWithTools = llm.bindTools([tool]);
|
||||
|
||||
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
|
||||
|
|
|
@ -18,12 +18,12 @@ const structuredOutput = z.object({
|
|||
.describe('Whether the index pattern contains the required fields for the query'),
|
||||
});
|
||||
|
||||
export const getExplorePartialIndexMappingResponder = ({
|
||||
export const getExplorePartialIndexMappingResponder = async ({
|
||||
createLlmInstance,
|
||||
}: {
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
}) => {
|
||||
const llm = createLlmInstance();
|
||||
const llm = await createLlmInstance();
|
||||
return async (state: typeof AnalyzeIndexPatternAnnotation.State) => {
|
||||
const { messages } = state;
|
||||
|
||||
|
|
|
@ -8,11 +8,6 @@
|
|||
import type { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server';
|
||||
import { END, START, StateGraph } from '@langchain/langgraph';
|
||||
import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import { ToolNode } from '@langchain/langgraph/prebuilt';
|
||||
import { GenerateEsqlAnnotation } from './state';
|
||||
|
||||
|
@ -41,8 +36,9 @@ import { getBuildUnvalidatedReportFromLastMessageNode } from './nodes/build_unva
|
|||
|
||||
import { getSelectIndexPattern } from './nodes/select_index_pattern/select_index_pattern';
|
||||
import { getSelectIndexPatternGraph } from '../select_index_pattern/select_index_pattern';
|
||||
import type { CreateLlmInstance } from '../../utils/common';
|
||||
|
||||
export const getGenerateEsqlGraph = ({
|
||||
export const getGenerateEsqlGraph = async ({
|
||||
esClient,
|
||||
connectorId,
|
||||
inference,
|
||||
|
@ -55,10 +51,7 @@ export const getGenerateEsqlGraph = ({
|
|||
inference: InferenceServerStart;
|
||||
logger: Logger;
|
||||
request: KibanaRequest;
|
||||
createLlmInstance: () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
}) => {
|
||||
const nlToEsqlAgentNode = getNlToEsqlAgent({
|
||||
connectorId,
|
||||
|
@ -90,7 +83,7 @@ export const getGenerateEsqlGraph = ({
|
|||
|
||||
const buildUnvalidatedReportFromLastMessageNode = getBuildUnvalidatedReportFromLastMessageNode();
|
||||
|
||||
const identifyIndexGraph = getSelectIndexPatternGraph({
|
||||
const identifyIndexGraph = await getSelectIndexPatternGraph({
|
||||
esClient,
|
||||
createLlmInstance,
|
||||
});
|
||||
|
|
|
@ -14,7 +14,7 @@ import { NL_TO_ESQL_AGENT_WITHOUT_VALIDATION_NODE } from '../../constants';
|
|||
export const getSelectIndexPattern = ({
|
||||
identifyIndexGraph,
|
||||
}: {
|
||||
identifyIndexGraph: ReturnType<typeof getSelectIndexPatternGraph>;
|
||||
identifyIndexGraph: Awaited<ReturnType<typeof getSelectIndexPatternGraph>>;
|
||||
}) => {
|
||||
return async (state: typeof GenerateEsqlAnnotation.State) => {
|
||||
const childGraphOutput = await identifyIndexGraph.invoke({
|
||||
|
|
|
@ -11,7 +11,7 @@ import type { getAnalyzeIndexPatternGraph } from '../../../analyse_index_pattern
|
|||
export const getAnalyzeIndexPattern = ({
|
||||
analyzeIndexPatternGraph,
|
||||
}: {
|
||||
analyzeIndexPatternGraph: ReturnType<typeof getAnalyzeIndexPatternGraph>;
|
||||
analyzeIndexPatternGraph: Awaited<ReturnType<typeof getAnalyzeIndexPatternGraph>>;
|
||||
}) => {
|
||||
return async ({
|
||||
input,
|
||||
|
|
|
@ -5,23 +5,11 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import { Command } from '@langchain/langgraph';
|
||||
|
||||
import type { SelectIndexPatternAnnotation } from '../../state';
|
||||
|
||||
export const getSelectIndexPattern = ({
|
||||
createLlmInstance,
|
||||
}: {
|
||||
createLlmInstance: () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
}) => {
|
||||
export const getSelectIndexPattern = () => {
|
||||
return async (state: typeof SelectIndexPatternAnnotation.State) => {
|
||||
const indexPatternAnalysis = Object.values(state.indexPatternAnalysis);
|
||||
const candidateIndexPatterns = indexPatternAnalysis.filter(
|
||||
|
|
|
@ -5,15 +5,11 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import { HumanMessage, SystemMessage } from '@langchain/core/messages';
|
||||
import { Command } from '@langchain/langgraph';
|
||||
import { z } from '@kbn/zod';
|
||||
import type { SelectIndexPatternAnnotation } from '../../state';
|
||||
import type { CreateLlmInstance } from '../../../../utils/common';
|
||||
|
||||
const ShortlistedIndexPatterns = z
|
||||
.object({
|
||||
|
@ -23,15 +19,12 @@ const ShortlistedIndexPatterns = z
|
|||
'Object containing array of shortlisted index patterns that might be used to generate the query'
|
||||
);
|
||||
|
||||
export const getShortlistIndexPatterns = ({
|
||||
export const getShortlistIndexPatterns = async ({
|
||||
createLlmInstance,
|
||||
}: {
|
||||
createLlmInstance: () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
}) => {
|
||||
const llm = createLlmInstance();
|
||||
const llm = await createLlmInstance();
|
||||
|
||||
return async (state: typeof SelectIndexPatternAnnotation.State) => {
|
||||
const systemMessage = new SystemMessage({
|
||||
|
|
|
@ -6,11 +6,6 @@
|
|||
*/
|
||||
|
||||
import { START, StateGraph, Send, END } from '@langchain/langgraph';
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { SelectIndexPatternAnnotation } from './state';
|
||||
import {
|
||||
|
@ -24,27 +19,30 @@ import { getShortlistIndexPatterns } from './nodes/shortlist_index_patterns/shor
|
|||
import { getAnalyzeIndexPattern } from './nodes/analyse_index_pattern/analyse_index_pattern';
|
||||
import { getSelectIndexPattern } from './nodes/select_index/select_index';
|
||||
import { getAnalyzeIndexPatternGraph } from '../analyse_index_pattern/analyse_index_pattern';
|
||||
import type { CreateLlmInstance } from '../../utils/common';
|
||||
|
||||
export const getSelectIndexPatternGraph = ({
|
||||
export const getSelectIndexPatternGraph = async ({
|
||||
createLlmInstance,
|
||||
esClient,
|
||||
}: {
|
||||
createLlmInstance: () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
createLlmInstance: CreateLlmInstance;
|
||||
esClient: ElasticsearchClient;
|
||||
}) => {
|
||||
const analyzeIndexPatternGraph = getAnalyzeIndexPatternGraph({
|
||||
esClient,
|
||||
createLlmInstance,
|
||||
});
|
||||
const [analyzeIndexPatternGraph, shortlistIndexPatterns] = await Promise.all([
|
||||
getAnalyzeIndexPatternGraph({
|
||||
esClient,
|
||||
createLlmInstance,
|
||||
}),
|
||||
getShortlistIndexPatterns({
|
||||
createLlmInstance,
|
||||
}),
|
||||
]);
|
||||
|
||||
const graph = new StateGraph(SelectIndexPatternAnnotation)
|
||||
.addNode(GET_INDEX_PATTERNS, fetchIndexPatterns({ esClient }), {
|
||||
retryPolicy: { maxAttempts: 3 },
|
||||
})
|
||||
.addNode(SHORTLIST_INDEX_PATTERNS, getShortlistIndexPatterns({ createLlmInstance }))
|
||||
.addNode(SHORTLIST_INDEX_PATTERNS, shortlistIndexPatterns)
|
||||
.addNode(
|
||||
ANALYZE_INDEX_PATTERN,
|
||||
getAnalyzeIndexPattern({
|
||||
|
@ -52,7 +50,7 @@ export const getSelectIndexPatternGraph = ({
|
|||
}),
|
||||
{ retryPolicy: { maxAttempts: 3 }, subgraphs: [analyzeIndexPatternGraph] }
|
||||
)
|
||||
.addNode(SELECT_INDEX_PATTERN, getSelectIndexPattern({ createLlmInstance }), {
|
||||
.addNode(SELECT_INDEX_PATTERN, getSelectIndexPattern(), {
|
||||
retryPolicy: { maxAttempts: 3 },
|
||||
})
|
||||
|
||||
|
|
|
@ -10,14 +10,13 @@ import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-
|
|||
import { lastValueFrom } from 'rxjs';
|
||||
import { naturalLanguageToEsql } from '@kbn/inference-plugin/server';
|
||||
import { z } from '@kbn/zod';
|
||||
import type { ElasticAssistantApiRequestHandlerContext } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
import { getPromptSuffixForOssModel } from './utils/common';
|
||||
|
||||
// select only some properties of AssistantToolParams
|
||||
export type ESQLToolParams = AssistantToolParams & {
|
||||
assistantContext: ElasticAssistantApiRequestHandlerContext;
|
||||
};
|
||||
|
||||
export type ESQLToolParams = Require<AssistantToolParams, 'assistantContext'>;
|
||||
|
||||
const TOOL_NAME = 'NaturalLanguageESQLTool';
|
||||
|
||||
|
@ -47,7 +46,7 @@ export const NL_TO_ESQL_TOOL: AssistantTool = {
|
|||
!assistantContext.getRegisteredFeatures('securitySolutionUI').advancedEsqlGeneration
|
||||
);
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, inference, logger, request, isOssModel } = params as ESQLToolParams;
|
||||
|
|
|
@ -0,0 +1,10 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
|
||||
export type CreateLlmInstance = Exclude<AssistantToolParams['createLlmInstance'], undefined>;
|
|
@ -5,14 +5,10 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import type {
|
||||
ActionsClientChatBedrockConverse,
|
||||
ActionsClientChatVertexAI,
|
||||
ActionsClientChatOpenAI,
|
||||
} from '@kbn/langchain/server';
|
||||
import type { BaseMessage } from '@langchain/core/messages';
|
||||
import { AIMessage } from '@langchain/core/messages';
|
||||
import type { ToolCall } from '@langchain/core/dist/messages/tool';
|
||||
import type { AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import { toolDetails } from '../tools/inspect_index_mapping_tool/inspect_index_mapping_tool';
|
||||
|
||||
export const getPromptSuffixForOssModel = (toolName: string) => `
|
||||
|
@ -30,10 +26,7 @@ export const messageContainsToolCalls = (message: BaseMessage): message is AIMes
|
|||
);
|
||||
};
|
||||
|
||||
export type CreateLlmInstance = () =>
|
||||
| ActionsClientChatBedrockConverse
|
||||
| ActionsClientChatVertexAI
|
||||
| ActionsClientChatOpenAI;
|
||||
export type CreateLlmInstance = Exclude<AssistantToolParams['createLlmInstance'], undefined>;
|
||||
|
||||
export const requireFirstInspectIndexMappingCallWithEmptyKey = (
|
||||
newMessage: AIMessage,
|
||||
|
|
|
@ -34,7 +34,9 @@ describe('KnowledgeBaseRetievalTool', () => {
|
|||
|
||||
describe('DynamicStructuredTool', () => {
|
||||
it('includes citations', async () => {
|
||||
const tool = KNOWLEDGE_BASE_RETRIEVAL_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await KNOWLEDGE_BASE_RETRIEVAL_TOOL.getTool(
|
||||
defaultArgs
|
||||
)) as DynamicStructuredTool;
|
||||
|
||||
getKnowledgeBaseDocumentEntries.mockResolvedValue([
|
||||
new Document({
|
||||
|
|
|
@ -8,15 +8,13 @@
|
|||
import { tool } from '@langchain/core/tools';
|
||||
import { z } from '@kbn/zod';
|
||||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base';
|
||||
import { Document } from 'langchain/document';
|
||||
import type { ContentReferencesStore } from '@kbn/elastic-assistant-common';
|
||||
import { knowledgeBaseReference, contentReferenceBlock } from '@kbn/elastic-assistant-common';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export interface KnowledgeBaseRetrievalToolParams extends AssistantToolParams {
|
||||
kbDataClient: AIAssistantKnowledgeBaseDataClient;
|
||||
}
|
||||
export type KnowledgeBaseRetrievalToolParams = Require<AssistantToolParams, 'kbDataClient'>;
|
||||
|
||||
const toolDetails = {
|
||||
// note: this description is overwritten when `getTool` is called
|
||||
|
@ -34,7 +32,7 @@ export const KNOWLEDGE_BASE_RETRIEVAL_TOOL: AssistantTool = {
|
|||
const { kbDataClient, isEnabledKnowledgeBase } = params;
|
||||
return isEnabledKnowledgeBase && kbDataClient != null;
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { kbDataClient, logger, contentReferencesStore } =
|
||||
|
|
|
@ -8,16 +8,15 @@
|
|||
import { tool } from '@langchain/core/tools';
|
||||
import { z } from '@kbn/zod';
|
||||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import type { AIAssistantKnowledgeBaseDataClient } from '@kbn/elastic-assistant-plugin/server/ai_assistant_data_clients/knowledge_base';
|
||||
import { DocumentEntryType } from '@kbn/elastic-assistant-common';
|
||||
import type { KnowledgeBaseEntryCreateProps } from '@kbn/elastic-assistant-common';
|
||||
import type { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export interface KnowledgeBaseWriteToolParams extends AssistantToolParams {
|
||||
kbDataClient: AIAssistantKnowledgeBaseDataClient;
|
||||
telemetry: AnalyticsServiceSetup;
|
||||
}
|
||||
export type KnowledgeBaseWriteToolParams = Require<
|
||||
AssistantToolParams,
|
||||
'kbDataClient' | 'telemetry'
|
||||
>;
|
||||
|
||||
const toolDetails = {
|
||||
// note: this description is overwritten when `getTool` is called
|
||||
|
@ -35,7 +34,7 @@ export const KNOWLEDGE_BASE_WRITE_TOOL: AssistantTool = {
|
|||
const { isEnabledKnowledgeBase, kbDataClient } = params;
|
||||
return isEnabledKnowledgeBase && kbDataClient != null;
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { telemetry, kbDataClient, logger } = params as KnowledgeBaseWriteToolParams;
|
||||
|
|
|
@ -137,7 +137,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
});
|
||||
describe('getTool', () => {
|
||||
it('returns a `DynamicTool` with a `func` that calls `esClient.search()` with the expected query', async () => {
|
||||
const tool: DynamicTool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
const tool: DynamicTool = (await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -145,7 +145,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
request,
|
||||
size: request.body.size,
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
await tool.func('');
|
||||
|
||||
|
@ -233,7 +233,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
});
|
||||
|
||||
it('includes citations', async () => {
|
||||
const tool: DynamicTool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
const tool: DynamicTool = (await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -241,7 +241,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
request,
|
||||
size: request.body.size,
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
(esClient.search as jest.Mock).mockResolvedValue({
|
||||
hits: {
|
||||
|
@ -263,8 +263,8 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
expect(result).toContain('Citation,{reference(exampleContentReferenceId)}');
|
||||
});
|
||||
|
||||
it('returns null when alertsIndexPattern is undefined', () => {
|
||||
const tool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
it('returns null when alertsIndexPattern is undefined', async () => {
|
||||
const tool = await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
// alertsIndexPattern is undefined
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -277,8 +277,8 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
expect(tool).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null when size is undefined', () => {
|
||||
const tool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
it('returns null when size is undefined', async () => {
|
||||
const tool = await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -291,8 +291,8 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
expect(tool).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null when size out of range', () => {
|
||||
const tool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
it('returns null when size out of range', async () => {
|
||||
const tool = await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -305,8 +305,8 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
expect(tool).toBeNull();
|
||||
});
|
||||
|
||||
it('returns a tool instance with the expected tags', () => {
|
||||
const tool = OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
it('returns a tool instance with the expected tags', async () => {
|
||||
const tool = (await OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL.getTool({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
onNewReplacements: jest.fn(),
|
||||
|
@ -314,7 +314,7 @@ describe('OpenAndAcknowledgedAlertsTool', () => {
|
|||
request,
|
||||
size: request.body.size,
|
||||
...rest,
|
||||
}) as DynamicTool;
|
||||
})) as DynamicTool;
|
||||
|
||||
expect(tool.tags).toEqual(['alerts', 'open-and-acknowledged-alerts']);
|
||||
});
|
||||
|
|
|
@ -19,12 +19,13 @@ import {
|
|||
import { tool } from '@langchain/core/tools';
|
||||
import { requestHasRequiredAnonymizationParams } from '@kbn/elastic-assistant-plugin/server/lib/langchain/helpers';
|
||||
import type { AssistantTool, AssistantToolParams } from '@kbn/elastic-assistant-plugin/server';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export interface OpenAndAcknowledgedAlertsToolParams extends AssistantToolParams {
|
||||
alertsIndexPattern: string;
|
||||
size: number;
|
||||
}
|
||||
export type OpenAndAcknowledgedAlertsToolParams = Require<
|
||||
AssistantToolParams,
|
||||
'alertsIndexPattern' | 'size'
|
||||
>;
|
||||
|
||||
export const OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL_DESCRIPTION =
|
||||
'Call this for knowledge about the latest n open and acknowledged alerts (sorted by `kibana.alert.risk_score`) in the environment, or when answering questions about open alerts. Do not call this tool for alert count or quantity. The output is an array of the latest n open and acknowledged alerts.';
|
||||
|
@ -50,7 +51,7 @@ export const OPEN_AND_ACKNOWLEDGED_ALERTS_TOOL: AssistantTool = {
|
|||
!sizeIsOutOfRange(size)
|
||||
);
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const {
|
||||
|
|
|
@ -57,14 +57,14 @@ describe('ProductDocumentationTool', () => {
|
|||
});
|
||||
|
||||
describe('getTool', () => {
|
||||
it('should return a tool as expected when all required values are present', () => {
|
||||
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs) as DynamicTool;
|
||||
it('should return a tool as expected when all required values are present', async () => {
|
||||
const tool = (await PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs)) as DynamicTool;
|
||||
expect(tool.name).toEqual('ProductDocumentationTool');
|
||||
expect(tool.tags).toEqual(['product-documentation']);
|
||||
});
|
||||
|
||||
it('returns null if llmTasks plugin is not provided', () => {
|
||||
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool({
|
||||
it('returns null if llmTasks plugin is not provided', async () => {
|
||||
const tool = await PRODUCT_DOCUMENTATION_TOOL.getTool({
|
||||
...defaultArgs,
|
||||
llmTasks: undefined,
|
||||
});
|
||||
|
@ -72,8 +72,8 @@ describe('ProductDocumentationTool', () => {
|
|||
expect(tool).toBeNull();
|
||||
});
|
||||
|
||||
it('returns null if connectorId is not provided', () => {
|
||||
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool({
|
||||
it('returns null if connectorId is not provided', async () => {
|
||||
const tool = await PRODUCT_DOCUMENTATION_TOOL.getTool({
|
||||
...defaultArgs,
|
||||
connectorId: undefined,
|
||||
});
|
||||
|
@ -86,7 +86,7 @@ describe('ProductDocumentationTool', () => {
|
|||
retrieveDocumentation.mockResolvedValue({ documents: [] });
|
||||
});
|
||||
it('the tool invokes retrieveDocumentation', async () => {
|
||||
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs)) as DynamicStructuredTool;
|
||||
|
||||
await tool.func({ query: 'What is Kibana Security?', product: 'kibana' });
|
||||
|
||||
|
@ -101,7 +101,7 @@ describe('ProductDocumentationTool', () => {
|
|||
});
|
||||
|
||||
it('includes citations', async () => {
|
||||
const tool = PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await PRODUCT_DOCUMENTATION_TOOL.getTool(defaultArgs)) as DynamicStructuredTool;
|
||||
|
||||
(retrieveDocumentation as jest.Mock).mockResolvedValue({
|
||||
documents: [
|
||||
|
|
|
@ -15,8 +15,14 @@ import {
|
|||
} from '@kbn/elastic-assistant-common';
|
||||
import type { ContentReferencesStore } from '@kbn/elastic-assistant-common';
|
||||
import type { RetrieveDocumentationResultDoc } from '@kbn/llm-tasks-plugin/server';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export type ProductDocumentationToolParams = Require<
|
||||
AssistantToolParams,
|
||||
'llmTasks' | 'connectorId'
|
||||
>;
|
||||
|
||||
const toolDetails = {
|
||||
// note: this description is overwritten when `getTool` is called
|
||||
// local definitions exist ../elastic_assistant/server/lib/prompt/tool_prompts.ts
|
||||
|
@ -29,17 +35,14 @@ const toolDetails = {
|
|||
export const PRODUCT_DOCUMENTATION_TOOL: AssistantTool = {
|
||||
...toolDetails,
|
||||
sourceRegister: APP_UI_ID,
|
||||
isSupported: (params: AssistantToolParams): params is AssistantToolParams => {
|
||||
isSupported: (params: AssistantToolParams): params is ProductDocumentationToolParams => {
|
||||
return params.llmTasks != null && params.connectorId != null;
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { connectorId, llmTasks, request, contentReferencesStore } =
|
||||
params as AssistantToolParams;
|
||||
|
||||
// This check is here in order to satisfy TypeScript
|
||||
if (llmTasks == null || connectorId == null) return null;
|
||||
params as ProductDocumentationToolParams;
|
||||
|
||||
return tool(
|
||||
async ({ query, product }) => {
|
||||
|
|
|
@ -53,7 +53,9 @@ In previous publications,`,
|
|||
}),
|
||||
]);
|
||||
|
||||
const tool = SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(
|
||||
defaultArgs
|
||||
)) as DynamicStructuredTool;
|
||||
|
||||
(contentReferencesStore.add as jest.Mock).mockImplementation(
|
||||
(creator: Parameters<ContentReferencesStore['add']>[0]) => {
|
||||
|
@ -81,7 +83,9 @@ In previous publications,`,
|
|||
pageContent: `hello world`,
|
||||
}),
|
||||
]);
|
||||
const tool = SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(
|
||||
defaultArgs
|
||||
)) as DynamicStructuredTool;
|
||||
|
||||
(contentReferencesStore.add as jest.Mock).mockImplementation(
|
||||
(creator: Parameters<ContentReferencesStore['add']>[0]) => {
|
||||
|
@ -104,7 +108,9 @@ In previous publications,`,
|
|||
it('Responds with The "AI Assistant knowledge base" needs to be installed... when no docs and no kb install', async () => {
|
||||
getKnowledgeBaseDocumentEntries.mockResolvedValue([]);
|
||||
(getIsKnowledgeBaseInstalled as jest.Mock).mockResolvedValue(false);
|
||||
const tool = SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(
|
||||
defaultArgs
|
||||
)) as DynamicStructuredTool;
|
||||
|
||||
const result = await tool.func({ query: 'What is Kibana Security?', product: 'kibana' });
|
||||
|
||||
|
@ -113,7 +119,9 @@ In previous publications,`,
|
|||
it('Responds with empty response when no docs and kb is installed', async () => {
|
||||
getKnowledgeBaseDocumentEntries.mockResolvedValue([]);
|
||||
(getIsKnowledgeBaseInstalled as jest.Mock).mockResolvedValue(true);
|
||||
const tool = SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(defaultArgs) as DynamicStructuredTool;
|
||||
const tool = (await SECURITY_LABS_KNOWLEDGE_BASE_TOOL.getTool(
|
||||
defaultArgs
|
||||
)) as DynamicStructuredTool;
|
||||
|
||||
const result = await tool.func({ query: 'What is Kibana Security?', product: 'kibana' });
|
||||
|
||||
|
|
|
@ -18,9 +18,12 @@ import {
|
|||
knowledgeBaseReference,
|
||||
} from '@kbn/elastic-assistant-common/impl/content_references/references';
|
||||
import { Document } from 'langchain/document';
|
||||
import type { Require } from '@kbn/elastic-assistant-plugin/server/types';
|
||||
import { getIsKnowledgeBaseInstalled } from '@kbn/elastic-assistant-plugin/server/routes/helpers';
|
||||
import { APP_UI_ID } from '../../../../common';
|
||||
|
||||
export type SecurityLabsKnowledgeBaseToolParams = Require<AssistantToolParams, 'kbDataClient'>;
|
||||
|
||||
const toolDetails = {
|
||||
// note: this description is overwritten when `getTool` is called
|
||||
// local definitions exist ../elastic_assistant/server/lib/prompt/tool_prompts.ts
|
||||
|
@ -36,15 +39,14 @@ const SECURITY_LABS_BASE_URL = 'https://www.elastic.co/security-labs/';
|
|||
export const SECURITY_LABS_KNOWLEDGE_BASE_TOOL: AssistantTool = {
|
||||
...toolDetails,
|
||||
sourceRegister: APP_UI_ID,
|
||||
isSupported: (params: AssistantToolParams): params is AssistantToolParams => {
|
||||
isSupported: (params: AssistantToolParams): params is SecurityLabsKnowledgeBaseToolParams => {
|
||||
const { kbDataClient, isEnabledKnowledgeBase } = params;
|
||||
return isEnabledKnowledgeBase && kbDataClient != null;
|
||||
},
|
||||
getTool(params: AssistantToolParams) {
|
||||
async getTool(params: AssistantToolParams) {
|
||||
if (!this.isSupported(params)) return null;
|
||||
|
||||
const { kbDataClient, contentReferencesStore } = params as AssistantToolParams;
|
||||
if (kbDataClient == null) return null;
|
||||
const { kbDataClient, contentReferencesStore } = params as SecurityLabsKnowledgeBaseToolParams;
|
||||
|
||||
return tool(
|
||||
async (input) => {
|
||||
|
|
|
@ -255,6 +255,6 @@
|
|||
"@kbn/react-kibana-context-theme",
|
||||
"@kbn/elastic-assistant-shared-state",
|
||||
"@kbn/elastic-assistant-shared-state-plugin",
|
||||
"@kbn/spaces-utils"
|
||||
"@kbn/spaces-utils",
|
||||
]
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue