mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
[Security AI] Add Kibana Support for Security AI Prompts Integration (#207138)
This commit is contained in:
parent
b998946003
commit
7af5a8338b
86 changed files with 1920 additions and 398 deletions
|
@ -914,6 +914,15 @@
|
|||
"username"
|
||||
],
|
||||
"search-telemetry": [],
|
||||
"security-ai-prompt": [
|
||||
"description",
|
||||
"model",
|
||||
"prompt",
|
||||
"prompt.default",
|
||||
"promptGroupId",
|
||||
"promptId",
|
||||
"provider"
|
||||
],
|
||||
"security-rule": [
|
||||
"rule_id",
|
||||
"version"
|
||||
|
|
|
@ -3023,6 +3023,33 @@
|
|||
"dynamic": false,
|
||||
"properties": {}
|
||||
},
|
||||
"security-ai-prompt": {
|
||||
"dynamic": false,
|
||||
"properties": {
|
||||
"description": {
|
||||
"type": "text"
|
||||
},
|
||||
"model": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"prompt": {
|
||||
"properties": {
|
||||
"default": {
|
||||
"type": "text"
|
||||
}
|
||||
}
|
||||
},
|
||||
"promptGroupId": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"promptId": {
|
||||
"type": "keyword"
|
||||
},
|
||||
"provider": {
|
||||
"type": "keyword"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security-rule": {
|
||||
"dynamic": false,
|
||||
"properties": {
|
||||
|
|
|
@ -11,4 +11,4 @@ export { registerCoreObjectTypes } from './registration';
|
|||
|
||||
// set minimum number of registered saved objects to ensure no object types are removed after 8.8
|
||||
// declared in internal implementation exclicilty to prevent unintended changes.
|
||||
export const SAVED_OBJECT_TYPES_COUNT = 127 as const;
|
||||
export const SAVED_OBJECT_TYPES_COUNT = 128 as const;
|
||||
|
|
|
@ -157,6 +157,7 @@ describe('checking migration metadata changes on all registered SO types', () =>
|
|||
"search": "0aa6eefb37edd3145be340a8b67779c2ca578b22",
|
||||
"search-session": "b2fcd840e12a45039ada50b1355faeafa39876d1",
|
||||
"search-telemetry": "b568601618744720b5662946d3103e3fb75fe8ee",
|
||||
"security-ai-prompt": "cc8ee5aaa9d001e89c131bbd5af6bc80bc271046",
|
||||
"security-rule": "07abb4d7e707d91675ec0495c73816394c7b521f",
|
||||
"security-solution-signals-migration": "9d99715fe5246f19de2273ba77debd2446c36bb1",
|
||||
"siem-detection-engine-rule-actions": "54f08e23887b20da7c805fab7c60bc67c428aff9",
|
||||
|
|
|
@ -123,6 +123,7 @@ const previouslyRegisteredTypes = [
|
|||
'search',
|
||||
'search-session',
|
||||
'search-telemetry',
|
||||
'security-ai-prompt',
|
||||
'security-rule',
|
||||
'security-solution-signals-migration',
|
||||
'risk-engine-configuration',
|
||||
|
|
|
@ -96,6 +96,7 @@ export {
|
|||
isInferenceRequestError,
|
||||
isInferenceRequestAbortedError,
|
||||
} from './src/errors';
|
||||
export { elasticModelDictionary } from './src/const';
|
||||
|
||||
export { truncateList } from './src/truncate_list';
|
||||
export {
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ElasticModelDictionary } from './types';
|
||||
|
||||
export const elasticModelDictionary: ElasticModelDictionary = {
|
||||
'rainbow-sprinkles': {
|
||||
provider: 'bedrock',
|
||||
model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
},
|
||||
};
|
|
@ -0,0 +1,13 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export interface ElasticModelDictionary {
|
||||
[key: string]: {
|
||||
provider: string;
|
||||
model: string;
|
||||
};
|
||||
}
|
|
@ -251,6 +251,7 @@ export const item: GetInfoResponse['item'] = {
|
|||
index_pattern: [],
|
||||
lens: [],
|
||||
map: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
csp_rule_template: [],
|
||||
tag: [],
|
||||
|
|
|
@ -106,6 +106,7 @@ export const item: GetInfoResponse['item'] = {
|
|||
ml_module: [],
|
||||
osquery_pack_asset: [],
|
||||
osquery_saved_query: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
csp_rule_template: [],
|
||||
tag: [],
|
||||
|
|
|
@ -33,6 +33,7 @@ describe('Fleet - packageToPackagePolicy', () => {
|
|||
map: [],
|
||||
lens: [],
|
||||
ml_module: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
tag: [],
|
||||
osquery_pack_asset: [],
|
||||
|
|
|
@ -59,6 +59,7 @@ export enum KibanaAssetType {
|
|||
indexPattern = 'index_pattern',
|
||||
map = 'map',
|
||||
mlModule = 'ml_module',
|
||||
securityAIPrompt = 'security_ai_prompt',
|
||||
securityRule = 'security_rule',
|
||||
cloudSecurityPostureRuleTemplate = 'csp_rule_template',
|
||||
osqueryPackAsset = 'osquery_pack_asset',
|
||||
|
@ -77,6 +78,7 @@ export enum KibanaSavedObjectType {
|
|||
indexPattern = 'index-pattern',
|
||||
map = 'map',
|
||||
mlModule = 'ml-module',
|
||||
securityAIPrompt = 'security-ai-prompt',
|
||||
securityRule = 'security-rule',
|
||||
cloudSecurityPostureRuleTemplate = 'csp-rule-template',
|
||||
osqueryPackAsset = 'osquery-pack-asset',
|
||||
|
|
|
@ -47,6 +47,12 @@ export const AssetTitleMap: Record<
|
|||
map: i18n.translate('xpack.fleet.epm.assetTitles.maps', {
|
||||
defaultMessage: 'Maps',
|
||||
}),
|
||||
'security-ai-prompt': i18n.translate('xpack.fleet.epm.assetTitles.securityAIPrompt', {
|
||||
defaultMessage: 'Security AI prompt',
|
||||
}),
|
||||
security_ai_prompt: i18n.translate('xpack.fleet.epm.assetTitles.securityAIPrompt', {
|
||||
defaultMessage: 'Security AI prompt',
|
||||
}),
|
||||
'security-rule': i18n.translate('xpack.fleet.epm.assetTitles.securityRules', {
|
||||
defaultMessage: 'Security rules',
|
||||
}),
|
||||
|
|
|
@ -193,6 +193,7 @@ describe('schema validation', () => {
|
|||
map: [],
|
||||
index_pattern: [],
|
||||
ml_module: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
tag: [],
|
||||
csp_rule_template: [],
|
||||
|
|
|
@ -38,6 +38,7 @@ packageInfoCache.set('test_package-0.0.0', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
@ -122,6 +123,7 @@ packageInfoCache.set('osquery_manager-0.3.0', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
@ -172,6 +174,7 @@ packageInfoCache.set('profiler_symbolizer-8.8.0-preview', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
@ -222,6 +225,7 @@ packageInfoCache.set('profiler_collector-8.9.0-preview', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
@ -264,6 +268,7 @@ packageInfoCache.set('apm-8.9.0-preview', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
@ -306,6 +311,7 @@ packageInfoCache.set('elastic_connectors-1.0.0', {
|
|||
index_pattern: [],
|
||||
map: [],
|
||||
lens: [],
|
||||
security_ai_prompt: [],
|
||||
security_rule: [],
|
||||
ml_module: [],
|
||||
tag: [],
|
||||
|
|
|
@ -33,9 +33,10 @@ import { deleteKibanaSavedObjectsAssets } from '../../packages/remove';
|
|||
import { FleetError, KibanaSOReferenceError } from '../../../../errors';
|
||||
import { withPackageSpan } from '../../packages/utils';
|
||||
|
||||
import { appContextService } from '../../..';
|
||||
|
||||
import { tagKibanaAssets } from './tag_assets';
|
||||
import { getSpaceAwareSaveobjectsClients } from './saved_objects';
|
||||
import { appContextService } from '../../..';
|
||||
|
||||
const MAX_ASSETS_TO_INSTALL_IN_PARALLEL = 1000;
|
||||
|
||||
|
@ -70,6 +71,7 @@ export const KibanaSavedObjectTypeMapping: Record<KibanaAssetType, KibanaSavedOb
|
|||
[KibanaAssetType.visualization]: KibanaSavedObjectType.visualization,
|
||||
[KibanaAssetType.lens]: KibanaSavedObjectType.lens,
|
||||
[KibanaAssetType.mlModule]: KibanaSavedObjectType.mlModule,
|
||||
[KibanaAssetType.securityAIPrompt]: KibanaSavedObjectType.securityAIPrompt,
|
||||
[KibanaAssetType.securityRule]: KibanaSavedObjectType.securityRule,
|
||||
[KibanaAssetType.cloudSecurityPostureRuleTemplate]:
|
||||
KibanaSavedObjectType.cloudSecurityPostureRuleTemplate,
|
||||
|
|
|
@ -31,3 +31,6 @@ export const CAPABILITIES = `${BASE_PATH}/capabilities`;
|
|||
Licensing requirements
|
||||
*/
|
||||
export const MINIMUM_AI_ASSISTANT_LICENSE = 'enterprise' as const;
|
||||
|
||||
// Saved Objects
|
||||
export const promptSavedObjectType = 'security-ai-prompt';
|
||||
|
|
|
@ -18,6 +18,19 @@ import type { Logger } from '@kbn/logging';
|
|||
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { FakeLLM } from '@langchain/core/utils/testing';
|
||||
import { createOpenAIFunctionsAgent } from 'langchain/agents';
|
||||
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
|
||||
import { savedObjectsClientMock } from '@kbn/core/server/mocks';
|
||||
import {
|
||||
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
} from '../server/lib/prompt/prompts';
|
||||
import { getDefaultAssistantGraph } from '../server/lib/langchain/graphs/default_assistant_graph/graph';
|
||||
import { getDefaultAttackDiscoveryGraph } from '../server/lib/attack_discovery/graphs/default_attack_discovery_graph';
|
||||
|
||||
|
@ -49,11 +62,13 @@ async function getAssistantGraph(logger: Logger): Promise<Drawable> {
|
|||
streamRunnable: false,
|
||||
});
|
||||
const graph = getDefaultAssistantGraph({
|
||||
actionsClient: actionsClientMock.create(),
|
||||
agentRunnable,
|
||||
logger,
|
||||
createLlmInstance,
|
||||
tools: [],
|
||||
replacements: {},
|
||||
savedObjectsClient: savedObjectsClientMock.create(),
|
||||
});
|
||||
return graph.getGraph();
|
||||
}
|
||||
|
@ -67,6 +82,17 @@ async function getAttackDiscoveryGraph(logger: Logger): Promise<Drawable> {
|
|||
llm: mockLlm as unknown as ActionsClientLlm,
|
||||
logger,
|
||||
replacements: {},
|
||||
prompts: {
|
||||
default: ATTACK_DISCOVERY_DEFAULT,
|
||||
refine: ATTACK_DISCOVERY_REFINE,
|
||||
continue: ATTACK_DISCOVERY_CONTINUE,
|
||||
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
title: ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
},
|
||||
size: 20,
|
||||
});
|
||||
|
||||
|
|
|
@ -54,8 +54,8 @@ export const createMockClients = () => {
|
|||
getCurrentUser: jest.fn(),
|
||||
inference: jest.fn(),
|
||||
llmTasks: jest.fn(),
|
||||
savedObjectsClient: core.savedObjects.client,
|
||||
},
|
||||
savedObjectsClient: core.savedObjects.client,
|
||||
|
||||
licensing: {
|
||||
...licensingMock.createRequestHandlerContext({ license }),
|
||||
|
@ -148,6 +148,7 @@ const createElasticAssistantRequestContextMock = (
|
|||
inference: { getClient: jest.fn() },
|
||||
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
|
||||
core: clients.core,
|
||||
savedObjectsClient: clients.elasticAssistant.savedObjectsClient,
|
||||
telemetry: clients.elasticAssistant.telemetry,
|
||||
};
|
||||
};
|
||||
|
|
|
@ -4,7 +4,6 @@
|
|||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
|
||||
import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks';
|
||||
|
@ -46,7 +45,6 @@ jest.mock('./helpers/get_evaluator_llm', () => {
|
|||
getEvaluatorLlm: jest.fn().mockResolvedValue(mockLlm),
|
||||
};
|
||||
});
|
||||
|
||||
const actionsClient = {
|
||||
get: jest.fn(),
|
||||
} as unknown as ActionsClient;
|
||||
|
@ -61,7 +59,22 @@ const logger = loggerMock.create();
|
|||
const mockEsClient = elasticsearchServiceMock.createElasticsearchClient();
|
||||
const runName = 'test-run-name';
|
||||
|
||||
const connectors = [mockExperimentConnector];
|
||||
const connectors = [
|
||||
{
|
||||
...mockExperimentConnector,
|
||||
prompts: {
|
||||
default: '',
|
||||
refine: '',
|
||||
continue: '',
|
||||
detailsMarkdown: '',
|
||||
entitySummaryMarkdown: '',
|
||||
mitreAttackTactics: '',
|
||||
summaryMarkdown: '',
|
||||
title: '',
|
||||
insights: '',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
const projectName = 'test-lang-smith-project';
|
||||
|
||||
|
|
|
@ -16,12 +16,16 @@ import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
|
|||
import { asyncForEach } from '@kbn/std';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
|
||||
import { CombinedPrompts } from '../graphs/default_attack_discovery_graph/nodes/helpers/prompts';
|
||||
import { DEFAULT_EVAL_ANONYMIZATION_FIELDS } from './constants';
|
||||
import { AttackDiscoveryGraphMetadata } from '../../langchain/graphs';
|
||||
import { DefaultAttackDiscoveryGraph } from '../graphs/default_attack_discovery_graph';
|
||||
import { getLlmType } from '../../../routes/utils';
|
||||
import { runEvaluations } from './run_evaluations';
|
||||
|
||||
interface ConnectorWithPrompts extends Connector {
|
||||
prompts: CombinedPrompts;
|
||||
}
|
||||
export const evaluateAttackDiscovery = async ({
|
||||
actionsClient,
|
||||
attackDiscoveryGraphs,
|
||||
|
@ -43,7 +47,7 @@ export const evaluateAttackDiscovery = async ({
|
|||
attackDiscoveryGraphs: AttackDiscoveryGraphMetadata[];
|
||||
alertsIndexPattern: string;
|
||||
anonymizationFields?: AnonymizationFieldResponse[];
|
||||
connectors: Connector[];
|
||||
connectors: ConnectorWithPrompts[];
|
||||
connectorTimeout: number;
|
||||
datasetName: string;
|
||||
esClient: ElasticsearchClient;
|
||||
|
@ -96,6 +100,7 @@ export const evaluateAttackDiscovery = async ({
|
|||
esClient,
|
||||
llm,
|
||||
logger,
|
||||
prompts: connector.prompts,
|
||||
size,
|
||||
});
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ const graphState: GraphState = {
|
|||
],
|
||||
combinedGenerations: 'generations',
|
||||
combinedRefinements: 'refinements',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
|
|
|
@ -35,6 +35,7 @@ const graphState: GraphState = {
|
|||
],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
|
|
|
@ -23,6 +23,7 @@ const initialGraphState: GraphState = {
|
|||
anonymizedAlerts: [...mockAnonymizedAlerts],
|
||||
combinedGenerations: 'generations',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 2,
|
||||
generations: ['gen', 'erations'],
|
||||
|
|
|
@ -19,6 +19,7 @@ const initialGraphState: GraphState = {
|
|||
anonymizedAlerts: [],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
|
|
|
@ -5,13 +5,14 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchClient, Logger } from '@kbn/core/server';
|
||||
import { ElasticsearchClient, Logger } from '@kbn/core/server';
|
||||
import { Replacements } from '@kbn/elastic-assistant-common';
|
||||
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
|
||||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import type { CompiledStateGraph } from '@langchain/langgraph';
|
||||
import { END, START, StateGraph } from '@langchain/langgraph';
|
||||
|
||||
import { CombinedPrompts } from './nodes/helpers/prompts';
|
||||
import { NodeType } from './constants';
|
||||
import { getGenerateOrEndEdge } from './edges/generate_or_end';
|
||||
import { getGenerateOrRefineOrEndEdge } from './edges/generate_or_refine_or_end';
|
||||
|
@ -32,6 +33,7 @@ export interface GetDefaultAttackDiscoveryGraphParams {
|
|||
llm: ActionsClientLlm;
|
||||
logger?: Logger;
|
||||
onNewReplacements?: (replacements: Replacements) => void;
|
||||
prompts: CombinedPrompts;
|
||||
replacements?: Replacements;
|
||||
size: number;
|
||||
start?: string;
|
||||
|
@ -55,6 +57,7 @@ export const getDefaultAttackDiscoveryGraph = ({
|
|||
llm,
|
||||
logger,
|
||||
onNewReplacements,
|
||||
prompts,
|
||||
replacements,
|
||||
size,
|
||||
start,
|
||||
|
@ -64,7 +67,7 @@ export const getDefaultAttackDiscoveryGraph = ({
|
|||
'generate' | 'refine' | 'retrieve_anonymized_alerts' | '__start__'
|
||||
> => {
|
||||
try {
|
||||
const graphState = getDefaultGraphState({ end, filter, start });
|
||||
const graphState = getDefaultGraphState({ end, filter, prompts, start });
|
||||
|
||||
// get nodes:
|
||||
const retrieveAnonymizedAlertsNode = getRetrieveAnonymizedAlertsNode({
|
||||
|
@ -80,11 +83,13 @@ export const getDefaultAttackDiscoveryGraph = ({
|
|||
const generateNode = getGenerateNode({
|
||||
llm,
|
||||
logger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm,
|
||||
logger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
// get edges:
|
||||
|
|
|
@ -31,6 +31,7 @@ const graphState: GraphState = {
|
|||
],
|
||||
combinedGenerations: 'combinedGenerations',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 2,
|
||||
generations: ['combined', 'Generations'],
|
||||
|
|
|
@ -6,13 +6,13 @@
|
|||
*/
|
||||
|
||||
import { getAlertsContextPrompt } from '.';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../../../helpers/get_default_attack_discovery_prompt';
|
||||
import { ATTACK_DISCOVERY_DEFAULT } from '../../../../../../../prompt/prompts';
|
||||
|
||||
describe('getAlertsContextPrompt', () => {
|
||||
it('generates the correct prompt', () => {
|
||||
const anonymizedAlerts = ['Alert 1', 'Alert 2', 'Alert 3'];
|
||||
|
||||
const expected = `${getDefaultAttackDiscoveryPrompt()}
|
||||
const expected = `${ATTACK_DISCOVERY_DEFAULT}
|
||||
|
||||
Use context from the following alerts to provide insights:
|
||||
|
||||
|
@ -27,7 +27,7 @@ Alert 3
|
|||
|
||||
const prompt = getAlertsContextPrompt({
|
||||
anonymizedAlerts,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
|
||||
});
|
||||
|
||||
expect(prompt).toEqual(expected);
|
||||
|
|
|
@ -16,6 +16,7 @@ const graphState: GraphState = {
|
|||
anonymizedAlerts: mockAnonymizedAlerts, // <-- mockAnonymizedAlerts is an array of objects with a pageContent property
|
||||
combinedGenerations: 'combinedGenerations',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 2,
|
||||
generations: ['combined', 'Generations'],
|
||||
|
|
|
@ -16,13 +16,16 @@ import {
|
|||
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getAnonymizedAlertsFromState } from './helpers/get_anonymized_alerts_from_state';
|
||||
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
|
||||
import { GraphState } from '../../types';
|
||||
import {
|
||||
getParsedAttackDiscoveriesMock,
|
||||
getRawAttackDiscoveriesMock,
|
||||
} from '../../../../../../__mocks__/raw_attack_discoveries';
|
||||
import {
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
} from '../../../../../prompt/prompts';
|
||||
|
||||
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
|
||||
|
||||
|
@ -49,10 +52,11 @@ let mockLlm: ActionsClientLlm;
|
|||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
|
||||
anonymizedAlerts: [...mockAnonymizedAlerts],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
|
@ -61,13 +65,25 @@ const initialGraphState: GraphState = {
|
|||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
refinePrompt: ATTACK_DISCOVERY_REFINE,
|
||||
replacements: {
|
||||
...mockAnonymizedAlertsReplacements,
|
||||
},
|
||||
unrefinedResults: null,
|
||||
};
|
||||
|
||||
const prompts = {
|
||||
default: '',
|
||||
refine: '',
|
||||
continue: '',
|
||||
detailsMarkdown: '',
|
||||
entitySummaryMarkdown: '',
|
||||
mitreAttackTactics: '',
|
||||
summaryMarkdown: '',
|
||||
title: '',
|
||||
insights: '',
|
||||
};
|
||||
|
||||
describe('getGenerateNode', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
@ -88,17 +104,20 @@ describe('getGenerateNode', () => {
|
|||
const generateNode = getGenerateNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
expect(typeof generateNode).toBe('function');
|
||||
});
|
||||
|
||||
it('invokes the chain with the expected alerts from state and formatting instructions', async () => {
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlm, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
await generateNode(initialGraphState);
|
||||
|
@ -121,7 +140,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
@ -129,6 +148,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const state = await generateNode(initialGraphState);
|
||||
|
@ -151,14 +171,15 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const mockLlmWithHallucination = new FakeLLM({
|
||||
response: hallucinatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
|
||||
.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithHallucination, prompts })
|
||||
.chain.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(hallucinatedResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithHallucination,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -185,14 +206,17 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const mockLlmWithRepeatedGenerations = new FakeLLM({
|
||||
response: repeatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
|
||||
.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
prompts,
|
||||
}).chain.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(repeatedResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -218,7 +242,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
@ -226,6 +250,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -257,7 +282,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
@ -265,6 +290,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -296,7 +322,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
@ -304,6 +330,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
|
|
@ -8,6 +8,7 @@
|
|||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
|
||||
import { GenerationPrompts } from '../helpers/prompts';
|
||||
import { discardPreviousGenerations } from './helpers/discard_previous_generations';
|
||||
import { extractJson } from '../helpers/extract_json';
|
||||
import { getAnonymizedAlertsFromState } from './helpers/get_anonymized_alerts_from_state';
|
||||
|
@ -23,9 +24,11 @@ import type { GraphState } from '../../types';
|
|||
export const getGenerateNode = ({
|
||||
llm,
|
||||
logger,
|
||||
prompts,
|
||||
}: {
|
||||
llm: ActionsClientLlm;
|
||||
logger?: Logger;
|
||||
prompts: GenerationPrompts;
|
||||
}): ((state: GraphState) => Promise<GraphState>) => {
|
||||
const generate = async (state: GraphState): Promise<GraphState> => {
|
||||
logger?.debug(() => `---GENERATE---`);
|
||||
|
@ -34,6 +37,7 @@ export const getGenerateNode = ({
|
|||
|
||||
const {
|
||||
attackDiscoveryPrompt,
|
||||
continuePrompt,
|
||||
combinedGenerations,
|
||||
generationAttempts,
|
||||
generations,
|
||||
|
@ -50,9 +54,13 @@ export const getGenerateNode = ({
|
|||
anonymizedAlerts,
|
||||
attackDiscoveryPrompt,
|
||||
combinedMaybePartialResults: combinedGenerations,
|
||||
continuePrompt,
|
||||
});
|
||||
|
||||
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions(llm);
|
||||
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions({
|
||||
llm,
|
||||
prompts,
|
||||
});
|
||||
|
||||
logger?.debug(
|
||||
() => `generate node is invoking the chain (${llmType}), attempt ${generationAttempts}`
|
||||
|
@ -112,6 +120,7 @@ export const getGenerateNode = ({
|
|||
llmType,
|
||||
logger,
|
||||
nodeName: 'generate',
|
||||
prompts,
|
||||
});
|
||||
|
||||
// use the unrefined results if we already reached the max number of retries:
|
||||
|
|
|
@ -13,72 +13,20 @@
|
|||
*/
|
||||
|
||||
import { z } from '@kbn/zod';
|
||||
import { GenerationPrompts } from '../../helpers/prompts';
|
||||
|
||||
export const SYNTAX = '{{ field.name fieldValue1 fieldValue2 fieldValueN }}';
|
||||
const GOOD_SYNTAX_EXAMPLES =
|
||||
'Examples of CORRECT syntax (includes field names and values): {{ host.name hostNameValue }} {{ user.name userNameValue }} {{ source.ip sourceIpValue }}';
|
||||
|
||||
const BAD_SYNTAX_EXAMPLES =
|
||||
'Examples of INCORRECT syntax (bad, because the field names are not included): {{ hostNameValue }} {{ userNameValue }} {{ sourceIpValue }}';
|
||||
|
||||
const RECONNAISSANCE = 'Reconnaissance';
|
||||
const INITIAL_ACCESS = 'Initial Access';
|
||||
const EXECUTION = 'Execution';
|
||||
const PERSISTENCE = 'Persistence';
|
||||
const PRIVILEGE_ESCALATION = 'Privilege Escalation';
|
||||
const DISCOVERY = 'Discovery';
|
||||
const LATERAL_MOVEMENT = 'Lateral Movement';
|
||||
const COMMAND_AND_CONTROL = 'Command and Control';
|
||||
const EXFILTRATION = 'Exfiltration';
|
||||
|
||||
const MITRE_ATTACK_TACTICS = [
|
||||
RECONNAISSANCE,
|
||||
INITIAL_ACCESS,
|
||||
EXECUTION,
|
||||
PERSISTENCE,
|
||||
PRIVILEGE_ESCALATION,
|
||||
DISCOVERY,
|
||||
LATERAL_MOVEMENT,
|
||||
COMMAND_AND_CONTROL,
|
||||
EXFILTRATION,
|
||||
] as const;
|
||||
|
||||
export const AttackDiscoveriesGenerationSchema = z.object({
|
||||
insights: z
|
||||
.array(
|
||||
z.object({
|
||||
alertIds: z.string().array().describe(`The alert IDs that the insight is based on.`),
|
||||
detailsMarkdown: z
|
||||
.string()
|
||||
.describe(
|
||||
`A detailed insight with markdown, where each markdown bullet contains a description of what happened that reads like a story of the attack as it played out and always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`
|
||||
),
|
||||
entitySummaryMarkdown: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe(
|
||||
`A short (no more than a sentence) summary of the insight featuring only the host.name and user.name fields (when they are applicable), using the same ${SYNTAX} syntax`
|
||||
),
|
||||
mitreAttackTactics: z
|
||||
.string()
|
||||
.array()
|
||||
.optional()
|
||||
.describe(
|
||||
`An array of MITRE ATT&CK tactic for the insight, using one of the following values: ${MITRE_ATTACK_TACTICS.join(
|
||||
','
|
||||
)}`
|
||||
),
|
||||
summaryMarkdown: z
|
||||
.string()
|
||||
.describe(`A markdown summary of insight, using the same ${SYNTAX} syntax`),
|
||||
title: z
|
||||
.string()
|
||||
.describe(
|
||||
'A short, no more than 7 words, title for the insight, NOT formatted with special syntax or markdown. This must be as brief as possible.'
|
||||
),
|
||||
})
|
||||
)
|
||||
.describe(
|
||||
`Insights with markdown that always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`
|
||||
),
|
||||
});
|
||||
export const getAttackDiscoveriesGenerationSchema = (prompts: GenerationPrompts) =>
|
||||
z.object({
|
||||
insights: z
|
||||
.array(
|
||||
z.object({
|
||||
alertIds: z.string().array().describe(`The alert IDs that the insight is based on.`),
|
||||
detailsMarkdown: z.string().describe(prompts.detailsMarkdown),
|
||||
entitySummaryMarkdown: z.string().optional().describe(prompts.entitySummaryMarkdown),
|
||||
mitreAttackTactics: z.string().array().optional().describe(prompts.mitreAttackTactics),
|
||||
summaryMarkdown: z.string().describe(prompts.summaryMarkdown),
|
||||
title: z.string().describe(prompts.title),
|
||||
})
|
||||
)
|
||||
.describe(prompts.insights),
|
||||
});
|
||||
|
|
|
@ -9,6 +9,23 @@ import { FakeLLM } from '@langchain/core/utils/testing';
|
|||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
|
||||
import { getChainWithFormatInstructions } from '.';
|
||||
import {
|
||||
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
} from '../../../../../../prompt/prompts';
|
||||
|
||||
const prompts = {
|
||||
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
title: ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
};
|
||||
|
||||
describe('getChainWithFormatInstructions', () => {
|
||||
const mockLlm = new FakeLLM({
|
||||
|
@ -32,7 +49,7 @@ Here is the JSON Schema instance your output must adhere to. Include the enclosi
|
|||
\`\`\`
|
||||
`;
|
||||
|
||||
const chainWithFormatInstructions = getChainWithFormatInstructions(mockLlm);
|
||||
const chainWithFormatInstructions = getChainWithFormatInstructions({ llm: mockLlm, prompts });
|
||||
expect(chainWithFormatInstructions).toEqual({
|
||||
chain: expect.any(Object),
|
||||
formatInstructions: expectedFormatInstructions,
|
||||
|
|
|
@ -9,6 +9,7 @@ import type { ActionsClientLlm } from '@kbn/langchain/server';
|
|||
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import { Runnable } from '@langchain/core/runnables';
|
||||
|
||||
import { GenerationPrompts } from '../prompts';
|
||||
import { getOutputParser } from '../get_output_parser';
|
||||
|
||||
interface GetChainWithFormatInstructions {
|
||||
|
@ -17,10 +18,14 @@ interface GetChainWithFormatInstructions {
|
|||
llmType: string;
|
||||
}
|
||||
|
||||
export const getChainWithFormatInstructions = (
|
||||
llm: ActionsClientLlm
|
||||
): GetChainWithFormatInstructions => {
|
||||
const outputParser = getOutputParser();
|
||||
export const getChainWithFormatInstructions = ({
|
||||
llm,
|
||||
prompts,
|
||||
}: {
|
||||
llm: ActionsClientLlm;
|
||||
prompts: GenerationPrompts;
|
||||
}): GetChainWithFormatInstructions => {
|
||||
const outputParser = getOutputParser(prompts);
|
||||
const formatInstructions = outputParser.getFormatInstructions();
|
||||
|
||||
const prompt = ChatPromptTemplate.fromTemplate(
|
||||
|
|
|
@ -6,12 +6,14 @@
|
|||
*/
|
||||
|
||||
import { getCombinedAttackDiscoveryPrompt } from '.';
|
||||
import { ATTACK_DISCOVERY_CONTINUE } from '../../../../../../prompt/prompts';
|
||||
|
||||
describe('getCombinedAttackDiscoveryPrompt', () => {
|
||||
it('returns the initial query when there are no partial results', () => {
|
||||
const result = getCombinedAttackDiscoveryPrompt({
|
||||
anonymizedAlerts: ['alert1', 'alert2'],
|
||||
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
combinedMaybePartialResults: '',
|
||||
});
|
||||
|
||||
|
@ -31,6 +33,7 @@ alert2
|
|||
const result = getCombinedAttackDiscoveryPrompt({
|
||||
anonymizedAlerts: ['alert1', 'alert2'],
|
||||
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
combinedMaybePartialResults: 'partialResults',
|
||||
});
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@
|
|||
import { isEmpty } from 'lodash/fp';
|
||||
|
||||
import { getAlertsContextPrompt } from '../../generate/helpers/get_alerts_context_prompt';
|
||||
import { getContinuePrompt } from '../get_continue_prompt';
|
||||
|
||||
/**
|
||||
* Returns the the initial query, or the initial query combined with a
|
||||
|
@ -18,11 +17,13 @@ export const getCombinedAttackDiscoveryPrompt = ({
|
|||
anonymizedAlerts,
|
||||
attackDiscoveryPrompt,
|
||||
combinedMaybePartialResults,
|
||||
continuePrompt,
|
||||
}: {
|
||||
anonymizedAlerts: string[];
|
||||
attackDiscoveryPrompt: string;
|
||||
/** combined results that may contain incomplete JSON */
|
||||
combinedMaybePartialResults: string;
|
||||
continuePrompt: string;
|
||||
}): string => {
|
||||
const alertsContextPrompt = getAlertsContextPrompt({
|
||||
anonymizedAlerts,
|
||||
|
@ -33,7 +34,7 @@ export const getCombinedAttackDiscoveryPrompt = ({
|
|||
? alertsContextPrompt // no partial results yet
|
||||
: `${alertsContextPrompt}
|
||||
|
||||
${getContinuePrompt()}
|
||||
${continuePrompt}
|
||||
|
||||
"""
|
||||
${combinedMaybePartialResults}
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getContinuePrompt } from '.';
|
||||
|
||||
describe('getContinuePrompt', () => {
|
||||
it('returns the expected prompt string', () => {
|
||||
const expectedPrompt = `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
|
||||
1) it MUST conform to the schema above, because it will be checked against the JSON schema
|
||||
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
|
||||
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
|
||||
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
|
||||
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
|
||||
`;
|
||||
|
||||
expect(getContinuePrompt()).toBe(expectedPrompt);
|
||||
});
|
||||
});
|
|
@ -1,15 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const getContinuePrompt =
|
||||
(): string => `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
|
||||
1) it MUST conform to the schema above, because it will be checked against the JSON schema
|
||||
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
|
||||
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
|
||||
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
|
||||
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
|
||||
`;
|
|
@ -1,16 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultAttackDiscoveryPrompt } from '.';
|
||||
|
||||
describe('getDefaultAttackDiscoveryPrompt', () => {
|
||||
it('returns the default attack discovery prompt', () => {
|
||||
expect(getDefaultAttackDiscoveryPrompt()).toEqual(
|
||||
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds)."
|
||||
);
|
||||
});
|
||||
});
|
|
@ -1,9 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const getDefaultAttackDiscoveryPrompt = (): string =>
|
||||
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).";
|
|
@ -5,10 +5,27 @@
|
|||
* 2.0.
|
||||
*/
|
||||
import { getOutputParser } from '.';
|
||||
import {
|
||||
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
} from '../../../../../../prompt/prompts';
|
||||
|
||||
const prompts = {
|
||||
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
title: ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
};
|
||||
|
||||
describe('getOutputParser', () => {
|
||||
it('returns a structured output parser with the expected format instructions', () => {
|
||||
const outputParser = getOutputParser();
|
||||
const outputParser = getOutputParser(prompts);
|
||||
|
||||
const expected = `You must format your output as a JSON value that adheres to a given \"JSON Schema\" instance.
|
||||
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
|
||||
import { StructuredOutputParser } from 'langchain/output_parsers';
|
||||
|
||||
import { AttackDiscoveriesGenerationSchema } from '../../generate/schema';
|
||||
import { GenerationPrompts } from '../prompts';
|
||||
import { getAttackDiscoveriesGenerationSchema } from '../../generate/schema';
|
||||
|
||||
export const getOutputParser = () =>
|
||||
StructuredOutputParser.fromZodSchema(AttackDiscoveriesGenerationSchema);
|
||||
export const getOutputParser = (prompts: GenerationPrompts) =>
|
||||
StructuredOutputParser.fromZodSchema(getAttackDiscoveriesGenerationSchema(prompts));
|
||||
|
|
|
@ -10,6 +10,15 @@ import type { Logger } from '@kbn/core/server';
|
|||
import { parseCombinedOrThrow } from '.';
|
||||
import { getRawAttackDiscoveriesMock } from '../../../../../../../__mocks__/raw_attack_discoveries';
|
||||
|
||||
const prompts = {
|
||||
detailsMarkdown: '',
|
||||
entitySummaryMarkdown: '',
|
||||
mitreAttackTactics: '',
|
||||
summaryMarkdown: '',
|
||||
title: '',
|
||||
insights: '',
|
||||
};
|
||||
|
||||
describe('parseCombinedOrThrow', () => {
|
||||
const mockLogger: Logger = {
|
||||
debug: jest.fn(),
|
||||
|
@ -28,6 +37,7 @@ describe('parseCombinedOrThrow', () => {
|
|||
nodeName,
|
||||
llmType,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
};
|
||||
|
||||
it('returns an Attack discovery for each insight in a valid combined response', () => {
|
||||
|
|
|
@ -8,9 +8,10 @@
|
|||
import type { Logger } from '@kbn/core/server';
|
||||
import type { AttackDiscovery } from '@kbn/elastic-assistant-common';
|
||||
|
||||
import { GenerationPrompts } from '../prompts';
|
||||
import { addTrailingBackticksIfNecessary } from '../add_trailing_backticks_if_necessary';
|
||||
import { extractJson } from '../extract_json';
|
||||
import { AttackDiscoveriesGenerationSchema } from '../../generate/schema';
|
||||
import { getAttackDiscoveriesGenerationSchema } from '../../generate/schema';
|
||||
|
||||
export const parseCombinedOrThrow = ({
|
||||
combinedResponse,
|
||||
|
@ -18,6 +19,7 @@ export const parseCombinedOrThrow = ({
|
|||
llmType,
|
||||
logger,
|
||||
nodeName,
|
||||
prompts,
|
||||
}: {
|
||||
/** combined responses that maybe valid JSON */
|
||||
combinedResponse: string;
|
||||
|
@ -25,6 +27,7 @@ export const parseCombinedOrThrow = ({
|
|||
nodeName: string;
|
||||
llmType: string;
|
||||
logger?: Logger;
|
||||
prompts: GenerationPrompts;
|
||||
}): AttackDiscovery[] => {
|
||||
const timestamp = new Date().toISOString();
|
||||
|
||||
|
@ -42,7 +45,7 @@ export const parseCombinedOrThrow = ({
|
|||
`${nodeName} node is validating combined response (${llmType}) from attempt ${generationAttempts}`
|
||||
);
|
||||
|
||||
const validatedResponse = AttackDiscoveriesGenerationSchema.parse(unvalidatedParsed);
|
||||
const validatedResponse = getAttackDiscoveriesGenerationSchema(prompts).parse(unvalidatedParsed);
|
||||
|
||||
logger?.debug(
|
||||
() =>
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import { getAttackDiscoveryPrompts } from '.';
|
||||
import { getPromptsByGroupId, promptDictionary } from '../../../../../../prompt';
|
||||
import { promptGroupId } from '../../../../../../prompt/local_prompt_object';
|
||||
|
||||
jest.mock('../../../../../../prompt', () => {
|
||||
const original = jest.requireActual('../../../../../../prompt');
|
||||
return {
|
||||
...original,
|
||||
getPromptsByGroupId: jest.fn(),
|
||||
};
|
||||
});
|
||||
const mockGetPromptsByGroupId = getPromptsByGroupId as jest.Mock;
|
||||
|
||||
describe('getAttackDiscoveryPrompts', () => {
|
||||
const actionsClient = {} as jest.Mocked<PublicMethodsOf<ActionsClient>>;
|
||||
const savedObjectsClient = {} as jest.Mocked<SavedObjectsClientContract>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
mockGetPromptsByGroupId.mockResolvedValue([
|
||||
{ promptId: promptDictionary.attackDiscoveryDefault, prompt: 'Default Prompt' },
|
||||
{ promptId: promptDictionary.attackDiscoveryRefine, prompt: 'Refine Prompt' },
|
||||
{ promptId: promptDictionary.attackDiscoveryContinue, prompt: 'Continue Prompt' },
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryDetailsMarkdown,
|
||||
prompt: 'Details Markdown Prompt',
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryEntitySummaryMarkdown,
|
||||
prompt: 'Entity Summary Markdown Prompt',
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryMitreAttackTactics,
|
||||
prompt: 'Mitre Attack Tactics Prompt',
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoverySummaryMarkdown,
|
||||
prompt: 'Summary Markdown Prompt',
|
||||
},
|
||||
{ promptId: promptDictionary.attackDiscoveryGenerationTitle, prompt: 'Title Prompt' },
|
||||
{ promptId: promptDictionary.attackDiscoveryGenerationInsights, prompt: 'Insights Prompt' },
|
||||
]);
|
||||
});
|
||||
|
||||
it('should return all prompts', async () => {
|
||||
const result = await getAttackDiscoveryPrompts({
|
||||
actionsClient,
|
||||
connectorId: 'test-connector-id',
|
||||
savedObjectsClient,
|
||||
model: '2',
|
||||
provider: 'gemini',
|
||||
});
|
||||
|
||||
expect(mockGetPromptsByGroupId).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
connectorId: 'test-connector-id',
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
model: '2',
|
||||
provider: 'gemini',
|
||||
promptIds: [
|
||||
promptDictionary.attackDiscoveryDefault,
|
||||
promptDictionary.attackDiscoveryRefine,
|
||||
promptDictionary.attackDiscoveryContinue,
|
||||
promptDictionary.attackDiscoveryDetailsMarkdown,
|
||||
promptDictionary.attackDiscoveryEntitySummaryMarkdown,
|
||||
promptDictionary.attackDiscoveryMitreAttackTactics,
|
||||
promptDictionary.attackDiscoverySummaryMarkdown,
|
||||
promptDictionary.attackDiscoveryGenerationTitle,
|
||||
promptDictionary.attackDiscoveryGenerationInsights,
|
||||
],
|
||||
})
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
default: 'Default Prompt',
|
||||
refine: 'Refine Prompt',
|
||||
continue: 'Continue Prompt',
|
||||
detailsMarkdown: 'Details Markdown Prompt',
|
||||
entitySummaryMarkdown: 'Entity Summary Markdown Prompt',
|
||||
mitreAttackTactics: 'Mitre Attack Tactics Prompt',
|
||||
summaryMarkdown: 'Summary Markdown Prompt',
|
||||
title: 'Title Prompt',
|
||||
insights: 'Insights Prompt',
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty strings for missing prompts', async () => {
|
||||
mockGetPromptsByGroupId.mockResolvedValue([]);
|
||||
|
||||
const result = await getAttackDiscoveryPrompts({
|
||||
actionsClient,
|
||||
connectorId: 'test-connector-id',
|
||||
savedObjectsClient,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
default: '',
|
||||
refine: '',
|
||||
continue: '',
|
||||
detailsMarkdown: '',
|
||||
entitySummaryMarkdown: '',
|
||||
mitreAttackTactics: '',
|
||||
summaryMarkdown: '',
|
||||
title: '',
|
||||
insights: '',
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
|
||||
import { getPromptsByGroupId, promptDictionary } from '../../../../../../prompt';
|
||||
import { promptGroupId } from '../../../../../../prompt/local_prompt_object';
|
||||
|
||||
export interface AttackDiscoveryPrompts {
|
||||
default: string;
|
||||
refine: string;
|
||||
continue: string;
|
||||
}
|
||||
|
||||
export interface GenerationPrompts {
|
||||
detailsMarkdown: string;
|
||||
entitySummaryMarkdown: string;
|
||||
mitreAttackTactics: string;
|
||||
summaryMarkdown: string;
|
||||
title: string;
|
||||
insights: string;
|
||||
}
|
||||
|
||||
export interface CombinedPrompts extends AttackDiscoveryPrompts, GenerationPrompts {}
|
||||
|
||||
export const getAttackDiscoveryPrompts = async ({
|
||||
actionsClient,
|
||||
connector,
|
||||
connectorId,
|
||||
model,
|
||||
provider,
|
||||
savedObjectsClient,
|
||||
}: {
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
connector?: Connector;
|
||||
connectorId: string;
|
||||
model?: string;
|
||||
provider?: string;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
}): Promise<CombinedPrompts> => {
|
||||
const prompts = await getPromptsByGroupId({
|
||||
actionsClient,
|
||||
connector,
|
||||
connectorId,
|
||||
// if in future oss has different prompt, add it as model here
|
||||
model,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
promptIds: [
|
||||
promptDictionary.attackDiscoveryDefault,
|
||||
promptDictionary.attackDiscoveryRefine,
|
||||
promptDictionary.attackDiscoveryContinue,
|
||||
promptDictionary.attackDiscoveryDetailsMarkdown,
|
||||
promptDictionary.attackDiscoveryEntitySummaryMarkdown,
|
||||
promptDictionary.attackDiscoveryMitreAttackTactics,
|
||||
promptDictionary.attackDiscoverySummaryMarkdown,
|
||||
promptDictionary.attackDiscoveryGenerationTitle,
|
||||
promptDictionary.attackDiscoveryGenerationInsights,
|
||||
],
|
||||
provider,
|
||||
savedObjectsClient,
|
||||
});
|
||||
|
||||
return {
|
||||
default:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryDefault)
|
||||
?.prompt || '',
|
||||
refine:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryRefine)
|
||||
?.prompt || '',
|
||||
continue:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryContinue)
|
||||
?.prompt || '',
|
||||
detailsMarkdown:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryDetailsMarkdown)
|
||||
?.prompt || '',
|
||||
entitySummaryMarkdown:
|
||||
prompts.find(
|
||||
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryEntitySummaryMarkdown
|
||||
)?.prompt || '',
|
||||
mitreAttackTactics:
|
||||
prompts.find(
|
||||
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryMitreAttackTactics
|
||||
)?.prompt || '',
|
||||
summaryMarkdown:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoverySummaryMarkdown)
|
||||
?.prompt || '',
|
||||
title:
|
||||
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryGenerationTitle)
|
||||
?.prompt || '',
|
||||
insights:
|
||||
prompts.find(
|
||||
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryGenerationInsights
|
||||
)?.prompt || '',
|
||||
};
|
||||
};
|
|
@ -15,6 +15,7 @@ const initialState: GraphState = {
|
|||
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
|
||||
combinedGenerations: 'generation1generation2',
|
||||
combinedRefinements: 'refinement1', // <-- existing refinements
|
||||
continuePrompt: 'continue',
|
||||
errors: [],
|
||||
generationAttempts: 3,
|
||||
generations: ['generation1', 'generation2'],
|
||||
|
|
|
@ -7,13 +7,14 @@
|
|||
|
||||
import { getCombinedRefinePrompt } from '.';
|
||||
import { mockAttackDiscoveries } from '../../../../../../evaluation/__mocks__/mock_attack_discoveries';
|
||||
import { getContinuePrompt } from '../../../helpers/get_continue_prompt';
|
||||
import { ATTACK_DISCOVERY_CONTINUE } from '../../../../../../../prompt/prompts';
|
||||
|
||||
describe('getCombinedRefinePrompt', () => {
|
||||
it('returns the base query when combinedRefinements is empty', () => {
|
||||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: [...mockAttackDiscoveries],
|
||||
});
|
||||
|
@ -33,6 +34,7 @@ ${JSON.stringify(mockAttackDiscoveries, null, 2)}
|
|||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: 'Combined refinements',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: [...mockAttackDiscoveries],
|
||||
});
|
||||
|
@ -47,7 +49,7 @@ ${JSON.stringify(mockAttackDiscoveries, null, 2)}
|
|||
|
||||
|
||||
|
||||
${getContinuePrompt()}
|
||||
${ATTACK_DISCOVERY_CONTINUE}
|
||||
|
||||
"""
|
||||
Combined refinements
|
||||
|
@ -60,6 +62,7 @@ Combined refinements
|
|||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: null,
|
||||
});
|
||||
|
|
|
@ -8,19 +8,19 @@
|
|||
import type { AttackDiscovery } from '@kbn/elastic-assistant-common';
|
||||
import { isEmpty } from 'lodash/fp';
|
||||
|
||||
import { getContinuePrompt } from '../../../helpers/get_continue_prompt';
|
||||
|
||||
/**
|
||||
* Returns a prompt that combines the initial query, a refine prompt, and partial results
|
||||
*/
|
||||
export const getCombinedRefinePrompt = ({
|
||||
attackDiscoveryPrompt,
|
||||
combinedRefinements,
|
||||
continuePrompt,
|
||||
refinePrompt,
|
||||
unrefinedResults,
|
||||
}: {
|
||||
attackDiscoveryPrompt: string;
|
||||
combinedRefinements: string;
|
||||
continuePrompt: string;
|
||||
refinePrompt: string;
|
||||
unrefinedResults: AttackDiscovery[] | null;
|
||||
}): string => {
|
||||
|
@ -38,7 +38,7 @@ ${JSON.stringify(unrefinedResults, null, 2)}
|
|||
? baseQuery // no partial results yet
|
||||
: `${baseQuery}
|
||||
|
||||
${getContinuePrompt()}
|
||||
${continuePrompt}
|
||||
|
||||
"""
|
||||
${combinedRefinements}
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultRefinePrompt } from '.';
|
||||
|
||||
describe('getDefaultRefinePrompt', () => {
|
||||
it('returns the default refine prompt string', () => {
|
||||
const result = getDefaultRefinePrompt();
|
||||
|
||||
expect(result)
|
||||
.toEqual(`You previously generated the following insights, but sometimes they represent the same attack.
|
||||
|
||||
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`);
|
||||
});
|
||||
});
|
|
@ -1,11 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const getDefaultRefinePrompt =
|
||||
(): string => `You previously generated the following insights, but sometimes they represent the same attack.
|
||||
|
||||
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`;
|
|
@ -16,16 +16,26 @@ import {
|
|||
mockAnonymizedAlertsReplacements,
|
||||
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from './helpers/get_default_refine_prompt';
|
||||
import { GraphState } from '../../types';
|
||||
import {
|
||||
getParsedAttackDiscoveriesMock,
|
||||
getRawAttackDiscoveriesMock,
|
||||
} from '../../../../../../__mocks__/raw_attack_discoveries';
|
||||
import {
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
} from '../../../../../prompt/prompts';
|
||||
|
||||
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
|
||||
|
||||
const prompts = {
|
||||
detailsMarkdown: '',
|
||||
entitySummaryMarkdown: '',
|
||||
mitreAttackTactics: '',
|
||||
summaryMarkdown: '',
|
||||
title: '',
|
||||
insights: '',
|
||||
};
|
||||
export const mockUnrefinedAttackDiscoveries: AttackDiscovery[] = [
|
||||
{
|
||||
title: 'unrefinedTitle1',
|
||||
|
@ -67,10 +77,11 @@ let mockLlm: ActionsClientLlm;
|
|||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
|
||||
anonymizedAlerts: [...mockAnonymizedAlerts],
|
||||
combinedGenerations: 'gen1',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
errors: [],
|
||||
generationAttempts: 1,
|
||||
generations: ['gen1'],
|
||||
|
@ -79,7 +90,7 @@ const initialGraphState: GraphState = {
|
|||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
refinePrompt: ATTACK_DISCOVERY_REFINE,
|
||||
replacements: {
|
||||
...mockAnonymizedAlertsReplacements,
|
||||
},
|
||||
|
@ -106,17 +117,20 @@ describe('getRefineNode', () => {
|
|||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
expect(typeof refineNode).toBe('function');
|
||||
});
|
||||
|
||||
it('invokes the chain with the unrefinedResults from state and format instructions', async () => {
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlm, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
await refineNode(initialGraphState);
|
||||
|
@ -125,7 +139,7 @@ describe('getRefineNode', () => {
|
|||
format_instructions: ['mock format instructions'],
|
||||
query: `${initialGraphState.attackDiscoveryPrompt}
|
||||
|
||||
${getDefaultRefinePrompt()}
|
||||
${ATTACK_DISCOVERY_REFINE}
|
||||
|
||||
\"\"\"
|
||||
${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
||||
|
@ -140,7 +154,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
@ -148,6 +162,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const state = await refineNode(initialGraphState);
|
||||
|
@ -170,14 +185,15 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const mockLlmWithHallucination = new FakeLLM({
|
||||
response: hallucinatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
|
||||
.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithHallucination, prompts })
|
||||
.chain.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(hallucinatedResponse);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithHallucination,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -203,14 +219,17 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const mockLlmWithRepeatedGenerations = new FakeLLM({
|
||||
response: repeatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
|
||||
.invoke as jest.Mock;
|
||||
const mockInvoke = getChainWithFormatInstructions({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
prompts,
|
||||
}).chain.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(repeatedResponse);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -236,7 +255,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
@ -244,6 +263,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -275,7 +295,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
@ -283,6 +303,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
@ -309,7 +330,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
@ -317,6 +338,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
|||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
prompts,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
|
|
|
@ -6,8 +6,9 @@
|
|||
*/
|
||||
|
||||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
import { Logger } from '@kbn/core/server';
|
||||
|
||||
import { GenerationPrompts } from '../helpers/prompts';
|
||||
import { discardPreviousRefinements } from './helpers/discard_previous_refinements';
|
||||
import { extractJson } from '../helpers/extract_json';
|
||||
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
|
||||
|
@ -24,9 +25,11 @@ import type { GraphState } from '../../types';
|
|||
export const getRefineNode = ({
|
||||
llm,
|
||||
logger,
|
||||
prompts,
|
||||
}: {
|
||||
llm: ActionsClientLlm;
|
||||
logger?: Logger;
|
||||
prompts: GenerationPrompts;
|
||||
}): ((state: GraphState) => Promise<GraphState>) => {
|
||||
const refine = async (state: GraphState): Promise<GraphState> => {
|
||||
logger?.debug(() => '---REFINE---');
|
||||
|
@ -34,6 +37,7 @@ export const getRefineNode = ({
|
|||
const {
|
||||
attackDiscoveryPrompt,
|
||||
combinedRefinements,
|
||||
continuePrompt,
|
||||
generationAttempts,
|
||||
hallucinationFailures,
|
||||
maxGenerationAttempts,
|
||||
|
@ -51,11 +55,15 @@ export const getRefineNode = ({
|
|||
const query = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt,
|
||||
combinedRefinements,
|
||||
continuePrompt,
|
||||
refinePrompt,
|
||||
unrefinedResults,
|
||||
});
|
||||
|
||||
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions(llm);
|
||||
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions({
|
||||
llm,
|
||||
prompts,
|
||||
});
|
||||
|
||||
logger?.debug(
|
||||
() => `refine node is invoking the chain (${llmType}), attempt ${generationAttempts}`
|
||||
|
@ -115,6 +123,7 @@ export const getRefineNode = ({
|
|||
llmType,
|
||||
logger,
|
||||
nodeName: 'refine',
|
||||
prompts,
|
||||
});
|
||||
|
||||
return {
|
||||
|
|
|
@ -11,16 +11,20 @@ import { Replacements } from '@kbn/elastic-assistant-common';
|
|||
|
||||
import { getRetrieveAnonymizedAlertsNode } from '.';
|
||||
import { mockAnonymizedAlerts } from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
|
||||
import type { GraphState } from '../../types';
|
||||
import {
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
} from '../../../../../prompt/prompts';
|
||||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
|
||||
anonymizedAlerts: [],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
|
@ -29,7 +33,7 @@ const initialGraphState: GraphState = {
|
|||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
refinePrompt: ATTACK_DISCOVERY_REFINE,
|
||||
replacements: {},
|
||||
unrefinedResults: null,
|
||||
};
|
||||
|
|
|
@ -11,110 +11,117 @@ import {
|
|||
DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
} from '../constants';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../nodes/helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../nodes/refine/helpers/get_default_refine_prompt';
|
||||
|
||||
const defaultAttackDiscoveryPrompt = getDefaultAttackDiscoveryPrompt();
|
||||
const defaultRefinePrompt = getDefaultRefinePrompt();
|
||||
import {
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
} from '../../../../prompt/prompts';
|
||||
|
||||
const defaultAttackDiscoveryPrompt = ATTACK_DISCOVERY_DEFAULT;
|
||||
const defaultRefinePrompt = ATTACK_DISCOVERY_REFINE;
|
||||
const prompts = {
|
||||
continue: ATTACK_DISCOVERY_CONTINUE,
|
||||
default: defaultAttackDiscoveryPrompt,
|
||||
refine: defaultRefinePrompt,
|
||||
};
|
||||
describe('getDefaultGraphState', () => {
|
||||
it('returns the expected default attackDiscoveries', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.attackDiscoveries?.default?.()).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the expected default attackDiscoveryPrompt', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.attackDiscoveryPrompt?.default?.()).toEqual(defaultAttackDiscoveryPrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default empty collection of anonymizedAlerts', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.anonymizedAlerts?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default combinedGenerations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.combinedGenerations?.default?.()).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default combinedRefinements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.combinedRefinements?.default?.()).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default errors state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.errors?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('return the expected default generationAttempts state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.generationAttempts?.default?.()).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default generations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.generations?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default hallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.hallucinationFailures?.default?.()).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default refinePrompt state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.refinePrompt?.default?.()).toEqual(defaultRefinePrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default maxGenerationAttempts state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.maxGenerationAttempts?.default?.()).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
|
||||
});
|
||||
|
||||
it('returns the expected default maxHallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
expect(state.maxHallucinationFailures?.default?.()).toBe(DEFAULT_MAX_HALLUCINATION_FAILURES);
|
||||
});
|
||||
|
||||
it('returns the expected default maxRepeatedGenerations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.maxRepeatedGenerations?.default?.()).toBe(DEFAULT_MAX_REPEATED_GENERATIONS);
|
||||
});
|
||||
|
||||
it('returns the expected default refinements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.refinements?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default replacements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.replacements?.default?.()).toEqual({});
|
||||
});
|
||||
|
||||
it('returns the expected default unrefinedResults state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.unrefinedResults?.default?.()).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the expected default end', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.end?.default?.()).toBeUndefined();
|
||||
});
|
||||
|
@ -122,13 +129,13 @@ describe('getDefaultGraphState', () => {
|
|||
it('returns the expected end when it is provided', () => {
|
||||
const end = '2025-01-02T00:00:00.000Z';
|
||||
|
||||
const state = getDefaultGraphState({ end });
|
||||
const state = getDefaultGraphState({ prompts, end });
|
||||
|
||||
expect(state.end?.default?.()).toEqual(end);
|
||||
});
|
||||
|
||||
it('returns the expected default filter to be undefined', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.filter?.default?.()).toBeUndefined();
|
||||
});
|
||||
|
@ -155,13 +162,13 @@ describe('getDefaultGraphState', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const state = getDefaultGraphState({ filter });
|
||||
const state = getDefaultGraphState({ prompts, filter });
|
||||
|
||||
expect(state.filter?.default?.()).toEqual(filter);
|
||||
});
|
||||
|
||||
it('returns the expected default start to be undefined', () => {
|
||||
const state = getDefaultGraphState();
|
||||
const state = getDefaultGraphState({ prompts });
|
||||
|
||||
expect(state.start?.default?.()).toBeUndefined();
|
||||
});
|
||||
|
@ -169,7 +176,7 @@ describe('getDefaultGraphState', () => {
|
|||
it('returns the expected start when it is provided', () => {
|
||||
const start = '2025-01-01T00:00:00.000Z';
|
||||
|
||||
const state = getDefaultGraphState({ start });
|
||||
const state = getDefaultGraphState({ prompts, start });
|
||||
|
||||
expect(state.start?.default?.()).toEqual(start);
|
||||
});
|
||||
|
|
|
@ -9,33 +9,34 @@ import { AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
|
|||
import type { Document } from '@langchain/core/documents';
|
||||
import type { StateGraphArgs } from '@langchain/langgraph';
|
||||
|
||||
import { AttackDiscoveryPrompts } from '../nodes/helpers/prompts';
|
||||
import {
|
||||
DEFAULT_MAX_GENERATION_ATTEMPTS,
|
||||
DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
} from '../constants';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../nodes/helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../nodes/refine/helpers/get_default_refine_prompt';
|
||||
import type { GraphState } from '../types';
|
||||
|
||||
export interface Options {
|
||||
end?: string;
|
||||
filter?: Record<string, unknown> | null;
|
||||
prompts: AttackDiscoveryPrompts;
|
||||
start?: string;
|
||||
}
|
||||
|
||||
export const getDefaultGraphState = ({
|
||||
end,
|
||||
filter,
|
||||
prompts,
|
||||
start,
|
||||
}: Options | undefined = {}): StateGraphArgs<GraphState>['channels'] => ({
|
||||
}: Options): StateGraphArgs<GraphState>['channels'] => ({
|
||||
attackDiscoveries: {
|
||||
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
|
||||
default: () => null,
|
||||
},
|
||||
attackDiscoveryPrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => getDefaultAttackDiscoveryPrompt(),
|
||||
default: () => prompts.default,
|
||||
},
|
||||
anonymizedAlerts: {
|
||||
value: (x: Document[], y?: Document[]) => y ?? x,
|
||||
|
@ -49,6 +50,10 @@ export const getDefaultGraphState = ({
|
|||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
continuePrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => prompts.continue,
|
||||
},
|
||||
end: {
|
||||
value: (x?: string | null, y?: string | null) => y ?? x,
|
||||
default: () => end,
|
||||
|
@ -75,7 +80,7 @@ export const getDefaultGraphState = ({
|
|||
},
|
||||
refinePrompt: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => getDefaultRefinePrompt(),
|
||||
default: () => prompts.refine,
|
||||
},
|
||||
maxGenerationAttempts: {
|
||||
value: (x: number, y?: number) => y ?? x,
|
||||
|
|
|
@ -14,6 +14,7 @@ export interface GraphState {
|
|||
anonymizedAlerts: Document[];
|
||||
combinedGenerations: string;
|
||||
combinedRefinements: string;
|
||||
continuePrompt: string;
|
||||
end?: string | null;
|
||||
errors: string[];
|
||||
filter?: Record<string, unknown> | null;
|
||||
|
|
|
@ -18,6 +18,7 @@ import type { InferenceServerStart } from '@kbn/inference-plugin/server';
|
|||
import { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
|
||||
import { TelemetryParams } from '@kbn/langchain/server/tracers/telemetry/telemetry_tracer';
|
||||
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import { ResponseBody } from '../types';
|
||||
import type { AssistantTool } from '../../../types';
|
||||
import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
|
||||
|
@ -57,6 +58,7 @@ export interface AgentExecutorParams<T extends boolean> {
|
|||
onLlmResponse?: OnLlmResponse;
|
||||
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
|
||||
response?: KibanaResponseFactory;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
size?: number;
|
||||
systemPrompt?: string;
|
||||
telemetry: AnalyticsServiceSetup;
|
||||
|
|
|
@ -14,6 +14,9 @@ import type { Logger } from '@kbn/logging';
|
|||
import { BaseMessage } from '@langchain/core/messages';
|
||||
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
|
||||
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import { AgentState, NodeParamsBase } from './types';
|
||||
import { AssistantDataClients } from '../../executors/types';
|
||||
|
||||
|
@ -30,10 +33,12 @@ import { NodeType } from './constants';
|
|||
export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';
|
||||
|
||||
export interface GetDefaultAssistantGraphParams {
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
agentRunnable: AgentRunnableSequence;
|
||||
dataClients?: AssistantDataClients;
|
||||
createLlmInstance: () => BaseChatModel;
|
||||
logger: Logger;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
signal?: AbortSignal;
|
||||
tools: StructuredTool[];
|
||||
replacements: Replacements;
|
||||
|
@ -42,10 +47,12 @@ export interface GetDefaultAssistantGraphParams {
|
|||
export type DefaultAssistantGraph = ReturnType<typeof getDefaultAssistantGraph>;
|
||||
|
||||
export const getDefaultAssistantGraph = ({
|
||||
actionsClient,
|
||||
agentRunnable,
|
||||
dataClients,
|
||||
createLlmInstance,
|
||||
logger,
|
||||
savedObjectsClient,
|
||||
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
|
||||
signal,
|
||||
tools,
|
||||
|
@ -97,6 +104,10 @@ export const getDefaultAssistantGraph = ({
|
|||
value: (x: boolean, y?: boolean) => y ?? x,
|
||||
default: () => false,
|
||||
},
|
||||
connectorId: {
|
||||
value: (x: string, y?: string) => y ?? x,
|
||||
default: () => '',
|
||||
},
|
||||
conversation: {
|
||||
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
|
||||
y ?? x,
|
||||
|
@ -114,7 +125,9 @@ export const getDefaultAssistantGraph = ({
|
|||
|
||||
// Default node parameters
|
||||
const nodeParams: NodeParamsBase = {
|
||||
actionsClient,
|
||||
logger,
|
||||
savedObjectsClient,
|
||||
};
|
||||
|
||||
// Put together a new graph using default state from above
|
||||
|
|
|
@ -56,6 +56,7 @@ describe('streamGraph', () => {
|
|||
input: 'input',
|
||||
responseLanguage: 'English',
|
||||
llmType: 'openai',
|
||||
connectorId: '123',
|
||||
},
|
||||
logger: mockLogger,
|
||||
onLlmResponse: mockOnLlmResponse,
|
||||
|
|
|
@ -18,6 +18,7 @@ import {
|
|||
createStructuredChatAgent,
|
||||
createToolCallingAgent,
|
||||
} from 'langchain/agents';
|
||||
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
|
||||
jest.mock('./graph');
|
||||
jest.mock('./helpers');
|
||||
jest.mock('langchain/agents');
|
||||
|
@ -41,6 +42,13 @@ describe('callAssistantGraph', () => {
|
|||
},
|
||||
};
|
||||
|
||||
const savedObjectsClient = savedObjectsClientMock.create();
|
||||
savedObjectsClient.find = jest.fn().mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 0,
|
||||
saved_objects: [],
|
||||
});
|
||||
const defaultParams = {
|
||||
actionsClient: actionsClientMock.create(),
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
|
@ -60,6 +68,7 @@ describe('callAssistantGraph', () => {
|
|||
onNewReplacements: jest.fn(),
|
||||
replacements: [],
|
||||
request: mockRequest,
|
||||
savedObjectsClient,
|
||||
size: 1,
|
||||
systemPrompt: 'test-prompt',
|
||||
telemetry: {},
|
||||
|
@ -167,6 +176,12 @@ describe('callAssistantGraph', () => {
|
|||
});
|
||||
|
||||
it('creates OpenAIToolsAgent for inference llmType', async () => {
|
||||
defaultParams.actionsClient.get = jest.fn().mockResolvedValue({
|
||||
config: {
|
||||
provider: 'elastic',
|
||||
providerConfig: { model_id: 'rainbow-sprinkles' },
|
||||
},
|
||||
});
|
||||
const params = { ...defaultParams, llmType: 'inference' };
|
||||
await callAssistantGraph(params);
|
||||
|
||||
|
|
|
@ -14,11 +14,14 @@ import {
|
|||
} from 'langchain/agents';
|
||||
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
|
||||
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
|
||||
import { promptGroupId } from '../../../prompt/local_prompt_object';
|
||||
import { getModelOrOss } from '../../../prompt/helpers';
|
||||
import { getPrompt, promptDictionary } from '../../../prompt';
|
||||
import { getLlmClass } from '../../../../routes/utils';
|
||||
import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types';
|
||||
import { AssistantToolParams } from '../../../../types';
|
||||
import { AgentExecutor } from '../../executors/types';
|
||||
import { formatPrompt, formatPromptStructured, systemPrompts } from './prompts';
|
||||
import { formatPrompt, formatPromptStructured } from './prompts';
|
||||
import { GraphInputs } from './types';
|
||||
import { getDefaultAssistantGraph } from './graph';
|
||||
import { invokeGraph, streamGraph } from './helpers';
|
||||
|
@ -44,6 +47,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
onNewReplacements,
|
||||
replacements,
|
||||
request,
|
||||
savedObjectsClient,
|
||||
size,
|
||||
systemPrompt,
|
||||
telemetry,
|
||||
|
@ -130,29 +134,36 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
}
|
||||
}
|
||||
|
||||
const defaultSystemPrompt = await getPrompt({
|
||||
actionsClient,
|
||||
connectorId,
|
||||
model: getModelOrOss(llmType, isOssModel, request.body.model),
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: llmType,
|
||||
savedObjectsClient,
|
||||
});
|
||||
|
||||
const agentRunnable =
|
||||
isOpenAI || llmType === 'inference'
|
||||
? await createOpenAIToolsAgent({
|
||||
llm: createLlmInstance(),
|
||||
tools,
|
||||
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
|
||||
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
|
||||
streamRunnable: isStream,
|
||||
})
|
||||
: llmType && ['bedrock', 'gemini'].includes(llmType)
|
||||
? await createToolCallingAgent({
|
||||
llm: createLlmInstance(),
|
||||
tools,
|
||||
prompt:
|
||||
llmType === 'bedrock'
|
||||
? formatPrompt(systemPrompts.bedrock, systemPrompt)
|
||||
: formatPrompt(systemPrompts.gemini, systemPrompt),
|
||||
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
|
||||
streamRunnable: isStream,
|
||||
})
|
||||
: // used with OSS models
|
||||
await createStructuredChatAgent({
|
||||
llm: createLlmInstance(),
|
||||
tools,
|
||||
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
|
||||
prompt: formatPromptStructured(defaultSystemPrompt, systemPrompt),
|
||||
streamRunnable: isStream,
|
||||
});
|
||||
|
||||
|
@ -174,6 +185,8 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
// we need to pass it like this or streaming does not work for bedrock
|
||||
createLlmInstance,
|
||||
logger,
|
||||
actionsClient,
|
||||
savedObjectsClient,
|
||||
tools,
|
||||
replacements,
|
||||
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
|
||||
|
@ -182,6 +195,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
|
|||
const inputs: GraphInputs = {
|
||||
responseLanguage,
|
||||
conversationId,
|
||||
connectorId,
|
||||
llmType,
|
||||
isStream,
|
||||
isOssModel,
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
|
||||
import { RunnableConfig } from '@langchain/core/runnables';
|
||||
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
|
||||
import { formatLatestUserMessage } from '../prompts';
|
||||
import { promptGroupId } from '../../../../prompt/local_prompt_object';
|
||||
import { getPrompt, promptDictionary } from '../../../../prompt';
|
||||
import { AgentState, NodeParamsBase } from '../types';
|
||||
import { NodeType } from '../constants';
|
||||
import { AIAssistantKnowledgeBaseDataClient } from '../../../../../ai_assistant_data_clients/knowledge_base';
|
||||
|
@ -34,7 +35,9 @@ const NO_KNOWLEDGE_HISTORY = '[No existing knowledge history]';
|
|||
* @param kbDataClient - Data client for accessing the Knowledge Base on behalf of the current user
|
||||
*/
|
||||
export async function runAgent({
|
||||
actionsClient,
|
||||
logger,
|
||||
savedObjectsClient,
|
||||
state,
|
||||
agentRunnable,
|
||||
config,
|
||||
|
@ -43,6 +46,17 @@ export async function runAgent({
|
|||
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);
|
||||
|
||||
const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries();
|
||||
const userPrompt =
|
||||
state.llmType === 'gemini'
|
||||
? await getPrompt({
|
||||
actionsClient,
|
||||
connectorId: state.connectorId,
|
||||
promptId: promptDictionary.userPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'gemini',
|
||||
savedObjectsClient,
|
||||
})
|
||||
: '';
|
||||
const agentOutcome = await agentRunnable
|
||||
.withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal })
|
||||
.invoke(
|
||||
|
@ -54,7 +68,7 @@ export async function runAgent({
|
|||
: NO_KNOWLEDGE_HISTORY
|
||||
}`,
|
||||
// prepend any user prompt (gemini)
|
||||
input: formatLatestUserMessage(state.input, state.llmType),
|
||||
input: `${userPrompt}${state.input}`,
|
||||
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
|
||||
},
|
||||
config
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
// TODO determine whether or not system prompts should be i18n'd
|
||||
const YOU_ARE_A_HELPFUL_EXPERT_ASSISTANT =
|
||||
'You are a security analyst and expert in resolving security incidents. Your role is to assist by answering questions about Elastic Security.';
|
||||
const IF_YOU_DONT_KNOW_THE_ANSWER = 'Do not answer questions unrelated to Elastic Security.';
|
||||
export const KNOWLEDGE_HISTORY =
|
||||
'If available, use the Knowledge History provided to try and answer the question. If not provided, you can try and query for additional knowledge via the KnowledgeBaseRetrievalTool.';
|
||||
|
||||
export const DEFAULT_SYSTEM_PROMPT = `${YOU_ARE_A_HELPFUL_EXPERT_ASSISTANT} ${IF_YOU_DONT_KNOW_THE_ANSWER} ${KNOWLEDGE_HISTORY}`;
|
||||
// system prompt from @afirstenberg
|
||||
const BASE_GEMINI_PROMPT =
|
||||
'You are an assistant that is an expert at using tools and Elastic Security, doing your best to use these tools to answer questions or follow instructions. It is very important to use tools to answer the question or follow the instructions rather than coming up with your own answer. Tool calls are good. Sometimes you may need to make several tool calls to accomplish the task or get an answer to the question that was asked. Use as many tool calls as necessary.';
|
||||
const KB_CATCH =
|
||||
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
|
||||
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
|
||||
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
|
||||
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
|
||||
|
||||
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
The tool action_input should ALWAYS follow the tool JSON schema args.
|
||||
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
\`\`\`
|
||||
|
||||
{{
|
||||
|
||||
"action": $TOOL_NAME,
|
||||
|
||||
"action_input": $TOOL_INPUT
|
||||
|
||||
}}
|
||||
|
||||
\`\`\`
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
|
||||
Thought: consider previous and subsequent steps
|
||||
|
||||
Action:
|
||||
|
||||
\`\`\`
|
||||
|
||||
$JSON_BLOB
|
||||
|
||||
\`\`\`
|
||||
|
||||
Observation: action result
|
||||
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
|
||||
Thought: I know what to respond
|
||||
|
||||
Action:
|
||||
|
||||
\`\`\`
|
||||
|
||||
{{
|
||||
|
||||
"action": "Final Answer",
|
||||
|
||||
"action_input": "Final response to human"}}
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
|
|
@ -6,13 +6,6 @@
|
|||
*/
|
||||
|
||||
import { ChatPromptTemplate } from '@langchain/core/prompts';
|
||||
import {
|
||||
BEDROCK_SYSTEM_PROMPT,
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
GEMINI_SYSTEM_PROMPT,
|
||||
GEMINI_USER_PROMPT,
|
||||
STRUCTURED_SYSTEM_PROMPT,
|
||||
} from './nodes/translations';
|
||||
|
||||
export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
|
||||
ChatPromptTemplate.fromMessages([
|
||||
|
@ -23,20 +16,6 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
|
|||
['placeholder', '{agent_scratchpad}'],
|
||||
]);
|
||||
|
||||
export const systemPrompts = {
|
||||
openai: DEFAULT_SYSTEM_PROMPT,
|
||||
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
|
||||
// The default prompt overwhelms gemini, do not prepend
|
||||
gemini: GEMINI_SYSTEM_PROMPT,
|
||||
structuredChat: STRUCTURED_SYSTEM_PROMPT,
|
||||
};
|
||||
|
||||
export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
|
||||
|
||||
export const bedrockToolCallingAgentPrompt = formatPrompt(systemPrompts.bedrock);
|
||||
|
||||
export const geminiToolCallingAgentPrompt = formatPrompt(systemPrompts.gemini);
|
||||
|
||||
export const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
|
||||
ChatPromptTemplate.fromMessages([
|
||||
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
|
||||
|
@ -47,18 +26,3 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string
|
|||
'{input}\n\n{agent_scratchpad}\n\n(reminder to respond in a JSON blob no matter what)',
|
||||
],
|
||||
]);
|
||||
|
||||
export const structuredChatAgentPrompt = formatPromptStructured(systemPrompts.structuredChat);
|
||||
|
||||
/**
|
||||
* If Gemini is the llmType,
|
||||
* Adds a user prompt for the latest message in a conversation
|
||||
* @param prompt
|
||||
* @param llmType
|
||||
*/
|
||||
export const formatLatestUserMessage = (prompt: string, llmType?: string): string => {
|
||||
if (llmType === 'gemini') {
|
||||
return `${GEMINI_USER_PROMPT}${prompt}`;
|
||||
}
|
||||
return prompt;
|
||||
};
|
||||
|
|
|
@ -9,6 +9,9 @@ import { BaseMessage } from '@langchain/core/messages';
|
|||
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import { ConversationResponse } from '@kbn/elastic-assistant-common';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
|
||||
export interface AgentStateBase {
|
||||
agentOutcome?: AgentAction | AgentFinish;
|
||||
|
@ -16,6 +19,7 @@ export interface AgentStateBase {
|
|||
}
|
||||
|
||||
export interface GraphInputs {
|
||||
connectorId: string;
|
||||
conversationId?: string;
|
||||
llmType?: string;
|
||||
isStream?: boolean;
|
||||
|
@ -34,10 +38,13 @@ export interface AgentState extends AgentStateBase {
|
|||
isOssModel: boolean;
|
||||
llmType: string;
|
||||
responseLanguage: string;
|
||||
connectorId: string;
|
||||
conversation: ConversationResponse | undefined;
|
||||
conversationId: string;
|
||||
}
|
||||
|
||||
export interface NodeParamsBase {
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
logger: Logger;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,520 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getPrompt, getPromptsByGroupId } from './get_prompt';
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { BEDROCK_SYSTEM_PROMPT, DEFAULT_SYSTEM_PROMPT, GEMINI_USER_PROMPT } from './prompts';
|
||||
import { promptDictionary, promptGroupId } from './local_prompt_object';
|
||||
|
||||
jest.mock('@kbn/core-saved-objects-api-server');
|
||||
jest.mock('@kbn/actions-plugin/server');
|
||||
const defaultConnector = {
|
||||
id: 'mock',
|
||||
name: 'Mock',
|
||||
isPreconfigured: false,
|
||||
isDeprecated: false,
|
||||
isSystemAction: false,
|
||||
actionTypeId: '.inference',
|
||||
};
|
||||
describe('get_prompt', () => {
|
||||
let savedObjectsClient: jest.Mocked<SavedObjectsClientContract>;
|
||||
let actionsClient: jest.Mocked<ActionsClient>;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
savedObjectsClient = {
|
||||
find: jest.fn().mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 3,
|
||||
saved_objects: [
|
||||
{
|
||||
type: 'security-ai-prompt',
|
||||
id: '977b39b8-5bb9-4530-9a39-7aa7084fb5c0',
|
||||
attributes: {
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o',
|
||||
description: 'Default prompt for AI Assistant system prompt.',
|
||||
prompt: {
|
||||
default: 'Hello world this is a system prompt',
|
||||
},
|
||||
},
|
||||
references: [],
|
||||
managed: false,
|
||||
updated_at: '2025-01-22T18:44:35.271Z',
|
||||
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
created_at: '2025-01-22T18:44:35.271Z',
|
||||
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
version: 'Wzk0MiwxXQ==',
|
||||
coreMigrationVersion: '8.8.0',
|
||||
score: 0.13353139,
|
||||
},
|
||||
{
|
||||
type: 'security-ai-prompt',
|
||||
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
|
||||
attributes: {
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
description: 'Default prompt for AI Assistant system prompt.',
|
||||
prompt: {
|
||||
default: 'Hello world this is a system prompt no model',
|
||||
},
|
||||
},
|
||||
references: [],
|
||||
managed: false,
|
||||
updated_at: '2025-01-22T19:11:48.806Z',
|
||||
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
created_at: '2025-01-22T19:11:48.806Z',
|
||||
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
version: 'Wzk4MCwxXQ==',
|
||||
coreMigrationVersion: '8.8.0',
|
||||
score: 0.13353139,
|
||||
},
|
||||
{
|
||||
type: 'security-ai-prompt',
|
||||
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
|
||||
attributes: {
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'bedrock',
|
||||
description: 'Default prompt for AI Assistant system prompt.',
|
||||
prompt: {
|
||||
default: 'Hello world this is a system prompt for bedrock',
|
||||
},
|
||||
},
|
||||
references: [],
|
||||
managed: false,
|
||||
updated_at: '2025-01-22T19:11:48.806Z',
|
||||
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
created_at: '2025-01-22T19:11:48.806Z',
|
||||
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
version: 'Wzk4MCwxXQ==',
|
||||
coreMigrationVersion: '8.8.0',
|
||||
score: 0.13353139,
|
||||
},
|
||||
{
|
||||
type: 'security-ai-prompt',
|
||||
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
|
||||
attributes: {
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'bedrock',
|
||||
model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
description: 'Default prompt for AI Assistant system prompt.',
|
||||
prompt: {
|
||||
default: 'Hello world this is a system prompt for bedrock claude-3-5-sonnet',
|
||||
},
|
||||
},
|
||||
references: [],
|
||||
managed: false,
|
||||
updated_at: '2025-01-22T19:11:48.806Z',
|
||||
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
created_at: '2025-01-22T19:11:48.806Z',
|
||||
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
version: 'Wzk4MCwxXQ==',
|
||||
coreMigrationVersion: '8.8.0',
|
||||
score: 0.13353139,
|
||||
},
|
||||
{
|
||||
type: 'security-ai-prompt',
|
||||
id: 'da530fad-87ce-49c3-a088-08073e5034d6',
|
||||
attributes: {
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
description: 'Default prompt for AI Assistant system prompt.',
|
||||
prompt: {
|
||||
default: 'Hello world this is a system prompt no model, no provider',
|
||||
},
|
||||
},
|
||||
references: [],
|
||||
managed: false,
|
||||
updated_at: '2025-01-22T19:12:12.911Z',
|
||||
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
created_at: '2025-01-22T19:12:12.911Z',
|
||||
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
|
||||
version: 'Wzk4MiwxXQ==',
|
||||
coreMigrationVersion: '8.8.0',
|
||||
score: 0.13353139,
|
||||
},
|
||||
],
|
||||
}),
|
||||
} as unknown as jest.Mocked<SavedObjectsClientContract>;
|
||||
|
||||
actionsClient = {
|
||||
get: jest.fn().mockResolvedValue({
|
||||
config: {
|
||||
provider: 'openai',
|
||||
providerConfig: { model_id: 'gpt-4o' },
|
||||
},
|
||||
}),
|
||||
} as unknown as jest.Mocked<ActionsClient>;
|
||||
});
|
||||
describe('getPrompt', () => {
|
||||
it('returns the prompt matching provider and model', async () => {
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(actionsClient.get).not.toHaveBeenCalled();
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt');
|
||||
});
|
||||
|
||||
it('returns the prompt matching provider when model does not have a match', async () => {
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o-mini',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(actionsClient.get).not.toHaveBeenCalled();
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt no model');
|
||||
});
|
||||
|
||||
it('returns the prompt matching provider when model is not provided', async () => {
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(actionsClient.get).toHaveBeenCalled();
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt no model');
|
||||
});
|
||||
|
||||
it('returns the default prompt when there is no match on provider', async () => {
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'badone',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt no model, no provider');
|
||||
});
|
||||
|
||||
it('defaults provider to bedrock if provider is "inference"', async () => {
|
||||
actionsClient.get.mockResolvedValue(defaultConnector);
|
||||
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'inference',
|
||||
model: 'gpt-4o',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock');
|
||||
});
|
||||
|
||||
it('returns the expected prompt from when provider is "elastic" and model matches in elasticModelDictionary', async () => {
|
||||
actionsClient.get.mockResolvedValue({
|
||||
...defaultConnector,
|
||||
config: {
|
||||
provider: 'elastic',
|
||||
providerConfig: { model_id: 'rainbow-sprinkles' },
|
||||
},
|
||||
});
|
||||
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'inference',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
|
||||
});
|
||||
|
||||
it('returns the bedrock prompt when provider is "elastic" but model does not match elasticModelDictionary', async () => {
|
||||
actionsClient.get.mockResolvedValue({
|
||||
...defaultConnector,
|
||||
config: {
|
||||
provider: 'elastic',
|
||||
providerConfig: { model_id: 'unknown-model' },
|
||||
},
|
||||
});
|
||||
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'inference',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock');
|
||||
});
|
||||
|
||||
it('returns the model prompt when no prompts are found and model is provided', async () => {
|
||||
savedObjectsClient.find.mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 0,
|
||||
saved_objects: [],
|
||||
});
|
||||
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
actionsClient,
|
||||
provider: 'bedrock',
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe(BEDROCK_SYSTEM_PROMPT);
|
||||
});
|
||||
|
||||
it('returns the default prompt when no prompts are found', async () => {
|
||||
savedObjectsClient.find.mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 0,
|
||||
saved_objects: [],
|
||||
});
|
||||
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe(DEFAULT_SYSTEM_PROMPT);
|
||||
});
|
||||
|
||||
it('throws an error when no prompts are found', async () => {
|
||||
savedObjectsClient.find.mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 0,
|
||||
saved_objects: [],
|
||||
});
|
||||
|
||||
await expect(
|
||||
getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: 'nonexistent-prompt',
|
||||
promptGroupId: 'nonexistent-group',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
})
|
||||
).rejects.toThrow(
|
||||
'Prompt not found for promptId: nonexistent-prompt and promptGroupId: nonexistent-group'
|
||||
);
|
||||
});
|
||||
|
||||
it('handles invalid connector configuration gracefully when provider is "inference"', async () => {
|
||||
actionsClient.get.mockResolvedValue({
|
||||
...defaultConnector,
|
||||
config: {},
|
||||
});
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'inference',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock');
|
||||
});
|
||||
|
||||
it('retrieves the connector when no model or provider is provided', async () => {
|
||||
actionsClient.get.mockResolvedValue({
|
||||
...defaultConnector,
|
||||
actionTypeId: '.bedrock',
|
||||
config: {
|
||||
defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
},
|
||||
});
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(actionsClient.get).toHaveBeenCalled();
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
|
||||
});
|
||||
|
||||
it('retrieves the connector when no model is provided', async () => {
|
||||
actionsClient.get.mockResolvedValue({
|
||||
...defaultConnector,
|
||||
actionTypeId: '.bedrock',
|
||||
config: {
|
||||
defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
|
||||
},
|
||||
});
|
||||
const result = await getPrompt({
|
||||
savedObjectsClient,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'bedrock',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(actionsClient.get).toHaveBeenCalled();
|
||||
|
||||
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getPromptsByGroupId', () => {
|
||||
it('returns prompts matching the provided promptIds', async () => {
|
||||
const result = await getPromptsByGroupId({
|
||||
savedObjectsClient,
|
||||
promptIds: [promptDictionary.systemPrompt],
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
model: 'gpt-4o',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
expect(savedObjectsClient.find).toHaveBeenCalledWith({
|
||||
type: 'security-ai-prompt',
|
||||
searchFields: ['promptGroupId'],
|
||||
search: promptGroupId.aiAssistant,
|
||||
});
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
prompt: 'Hello world this is a system prompt',
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('returns prompts matching the provided promptIds for gemini', async () => {
|
||||
const result = await getPromptsByGroupId({
|
||||
savedObjectsClient,
|
||||
promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt],
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'gemini',
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
prompt: 'Hello world this is a system prompt no model, no provider',
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.userPrompt,
|
||||
prompt: GEMINI_USER_PROMPT,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('returns prompts matching the provided promptIds when connector is given', async () => {
|
||||
const result = await getPromptsByGroupId({
|
||||
savedObjectsClient,
|
||||
promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt],
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
connector: {
|
||||
actionTypeId: '.gemini',
|
||||
config: {
|
||||
defaultModel: 'gemini-1.5-pro-002',
|
||||
},
|
||||
id: 'connector-123',
|
||||
name: 'Gemini',
|
||||
isPreconfigured: false,
|
||||
isDeprecated: false,
|
||||
isSystemAction: false,
|
||||
},
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
prompt: 'Hello world this is a system prompt no model, no provider',
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.userPrompt,
|
||||
prompt: GEMINI_USER_PROMPT,
|
||||
},
|
||||
]);
|
||||
});
|
||||
it('returns prompts matching the provided promptIds when inference connector is given', async () => {
|
||||
const result = await getPromptsByGroupId({
|
||||
savedObjectsClient,
|
||||
promptIds: [promptDictionary.systemPrompt],
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
connector: {
|
||||
actionTypeId: '.inference',
|
||||
config: {
|
||||
provider: 'elastic',
|
||||
providerConfig: { model_id: 'rainbow-sprinkles' },
|
||||
},
|
||||
id: 'connector-123',
|
||||
name: 'Inference',
|
||||
isPreconfigured: false,
|
||||
isDeprecated: false,
|
||||
isSystemAction: false,
|
||||
},
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
});
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
prompt: 'Hello world this is a system prompt for bedrock',
|
||||
},
|
||||
]);
|
||||
});
|
||||
it('throws an error when a prompt is missing', async () => {
|
||||
savedObjectsClient.find.mockResolvedValue({
|
||||
page: 1,
|
||||
per_page: 20,
|
||||
total: 0,
|
||||
saved_objects: [],
|
||||
});
|
||||
|
||||
await expect(
|
||||
getPromptsByGroupId({
|
||||
savedObjectsClient,
|
||||
promptIds: [promptDictionary.systemPrompt, 'fake-id'],
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
actionsClient,
|
||||
connectorId: 'connector-123',
|
||||
})
|
||||
).rejects.toThrow('Prompt not found for promptId: fake-id and promptGroupId: aiAssistant');
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,222 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
|
||||
import { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
|
||||
import { elasticModelDictionary } from '@kbn/inference-common';
|
||||
import { Prompt } from './types';
|
||||
import { localPrompts } from './local_prompt_object';
|
||||
import { getLlmType } from '../../routes/utils';
|
||||
import { promptSavedObjectType } from '../../../common/constants';
|
||||
interface GetPromptArgs {
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
connector?: Connector;
|
||||
connectorId: string;
|
||||
model?: string;
|
||||
promptId: string;
|
||||
promptGroupId: string;
|
||||
provider?: string;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
}
|
||||
interface GetPromptsByGroupIdArgs extends Omit<GetPromptArgs, 'promptId'> {
|
||||
promptGroupId: string;
|
||||
promptIds: string[];
|
||||
}
|
||||
|
||||
type PromptArray = Array<{ promptId: string; prompt: string }>;
|
||||
/**
|
||||
* Get prompts by feature (promptGroupId)
|
||||
* provide either model + provider or connector to avoid additional calls to get connector
|
||||
* @param actionsClient - actions client
|
||||
* @param connector - connector, provide if available. No need to provide model and provider in this case
|
||||
* @param connectorId - connector id
|
||||
* @param model - model. No need to provide if connector provided
|
||||
* @param promptGroupId - feature id, should be common across promptIds
|
||||
* @param promptIds - prompt ids with shared promptGroupId
|
||||
* @param provider - provider. No need to provide if connector provided
|
||||
* @param savedObjectsClient - saved objects client
|
||||
*/
|
||||
export const getPromptsByGroupId = async ({
|
||||
actionsClient,
|
||||
connector,
|
||||
connectorId,
|
||||
model: providedModel,
|
||||
promptGroupId,
|
||||
promptIds,
|
||||
provider: providedProvider,
|
||||
savedObjectsClient,
|
||||
}: GetPromptsByGroupIdArgs): Promise<PromptArray> => {
|
||||
const { provider, model } = await resolveProviderAndModel({
|
||||
providedProvider,
|
||||
providedModel,
|
||||
connectorId,
|
||||
actionsClient,
|
||||
providedConnector: connector,
|
||||
});
|
||||
|
||||
const prompts = await savedObjectsClient.find<Prompt>({
|
||||
type: promptSavedObjectType,
|
||||
searchFields: ['promptGroupId'],
|
||||
search: promptGroupId,
|
||||
});
|
||||
const promptsOnly = prompts?.saved_objects.map((p) => p.attributes) ?? [];
|
||||
|
||||
return promptIds.map((promptId) => {
|
||||
const prompt = findPromptEntry({
|
||||
prompts: promptsOnly.filter((p) => p.promptId === promptId) ?? [],
|
||||
promptId,
|
||||
promptGroupId,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
if (!prompt) {
|
||||
throw new Error(
|
||||
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
promptId,
|
||||
prompt,
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Get prompt by promptId
|
||||
* provide either model + provider or connector to avoid additional calls to get connector
|
||||
* @param actionsClient - actions client
|
||||
* @param connector - connector, provide if available. No need to provide model and provider in this case
|
||||
* @param connectorId - connector id
|
||||
* @param model - model. No need to provide if connector provided
|
||||
* @param promptId - prompt id
|
||||
* @param promptGroupId - feature id, should be common across promptIds
|
||||
* @param provider - provider. No need to provide if connector provided
|
||||
* @param savedObjectsClient - saved objects client
|
||||
*/
|
||||
export const getPrompt = async ({
|
||||
actionsClient,
|
||||
connector,
|
||||
connectorId,
|
||||
model: providedModel,
|
||||
promptGroupId,
|
||||
promptId,
|
||||
provider: providedProvider,
|
||||
savedObjectsClient,
|
||||
}: GetPromptArgs): Promise<string> => {
|
||||
const { provider, model } = await resolveProviderAndModel({
|
||||
providedProvider,
|
||||
providedModel,
|
||||
connectorId,
|
||||
actionsClient,
|
||||
providedConnector: connector,
|
||||
});
|
||||
|
||||
const prompts = await savedObjectsClient.find<Prompt>({
|
||||
type: promptSavedObjectType,
|
||||
filter: `${promptSavedObjectType}.attributes.promptId: "${promptId}" AND ${promptSavedObjectType}.attributes.promptGroupId: "${promptGroupId}"`,
|
||||
fields: ['provider', 'model', 'prompt'],
|
||||
});
|
||||
|
||||
const prompt = findPromptEntry({
|
||||
prompts: prompts?.saved_objects.map((p) => p.attributes) ?? [],
|
||||
promptId,
|
||||
promptGroupId,
|
||||
provider,
|
||||
model,
|
||||
});
|
||||
if (!prompt) {
|
||||
throw new Error(
|
||||
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
|
||||
);
|
||||
}
|
||||
|
||||
return prompt;
|
||||
};
|
||||
|
||||
const resolveProviderAndModel = async ({
|
||||
providedProvider,
|
||||
providedModel,
|
||||
connectorId,
|
||||
actionsClient,
|
||||
providedConnector,
|
||||
}: {
|
||||
providedProvider: string | undefined;
|
||||
providedModel: string | undefined;
|
||||
connectorId: string;
|
||||
actionsClient: PublicMethodsOf<ActionsClient>;
|
||||
providedConnector?: Connector;
|
||||
}): Promise<{ provider?: string; model?: string }> => {
|
||||
let model = providedModel;
|
||||
let provider = providedProvider;
|
||||
if (!provider || !model || provider === 'inference') {
|
||||
const connector = providedConnector ?? (await actionsClient.get({ id: connectorId }));
|
||||
|
||||
if (provider === 'inference' && connector.config) {
|
||||
provider = connector.config.provider || provider;
|
||||
model = connector.config.providerConfig?.model_id || model;
|
||||
|
||||
if (provider === 'elastic' && model) {
|
||||
provider = elasticModelDictionary[model]?.provider || 'inference';
|
||||
model = elasticModelDictionary[model]?.model;
|
||||
}
|
||||
} else if (connector.config) {
|
||||
provider = provider || getLlmType(connector.actionTypeId);
|
||||
model = model || connector.config.defaultModel;
|
||||
}
|
||||
}
|
||||
|
||||
return { provider: provider === 'inference' ? 'bedrock' : provider, model };
|
||||
};
|
||||
|
||||
const findPrompt = ({
|
||||
prompts,
|
||||
conditions,
|
||||
}: {
|
||||
prompts: Array<{ provider?: string; model?: string; prompt: { default: string } }>;
|
||||
conditions: Array<(prompt: { provider?: string; model?: string }) => boolean>;
|
||||
}): string | undefined => {
|
||||
for (const condition of conditions) {
|
||||
const match = prompts.find(condition);
|
||||
if (match) return match.prompt.default;
|
||||
}
|
||||
return undefined;
|
||||
};
|
||||
|
||||
const findPromptEntry = ({
|
||||
prompts,
|
||||
promptId,
|
||||
promptGroupId,
|
||||
provider,
|
||||
model,
|
||||
}: {
|
||||
prompts: Prompt[];
|
||||
promptId: string;
|
||||
promptGroupId: string;
|
||||
provider?: string;
|
||||
model?: string;
|
||||
}): string | undefined => {
|
||||
const conditions = [
|
||||
(prompt: { provider?: string; model?: string }) =>
|
||||
prompt.provider === provider && prompt.model === model,
|
||||
(prompt: { provider?: string; model?: string }) =>
|
||||
prompt.provider === provider && !prompt.model,
|
||||
(prompt: { provider?: string; model?: string }) => !prompt.provider && !prompt.model,
|
||||
];
|
||||
|
||||
return (
|
||||
findPrompt({ prompts, conditions }) ??
|
||||
findPrompt({
|
||||
prompts: localPrompts.filter(
|
||||
(p) => p.promptId === promptId && p.promptGroupId === promptGroupId
|
||||
),
|
||||
conditions,
|
||||
})
|
||||
);
|
||||
};
|
|
@ -0,0 +1,20 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
/**
|
||||
* use oss as model when using openai and oss
|
||||
* else default to given model
|
||||
* if no model exists, let undefined and resolveProviderAndModel logic will determine the model from connector
|
||||
* @param llmType
|
||||
* @param isOssModel
|
||||
* @param model
|
||||
*/
|
||||
export const getModelOrOss = (
|
||||
llmType?: string,
|
||||
isOssModel?: boolean,
|
||||
model?: string
|
||||
): string | undefined => (llmType === 'openai' && isOssModel ? 'oss' : model);
|
|
@ -0,0 +1,9 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export { getPrompt, getPromptsByGroupId } from './get_prompt';
|
||||
export { promptDictionary } from './local_prompt_object';
|
|
@ -0,0 +1,157 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { Prompt } from './types';
|
||||
import {
|
||||
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
ATTACK_DISCOVERY_CONTINUE,
|
||||
ATTACK_DISCOVERY_DEFAULT,
|
||||
ATTACK_DISCOVERY_REFINE,
|
||||
BEDROCK_SYSTEM_PROMPT,
|
||||
DEFAULT_SYSTEM_PROMPT,
|
||||
GEMINI_SYSTEM_PROMPT,
|
||||
GEMINI_USER_PROMPT,
|
||||
STRUCTURED_SYSTEM_PROMPT,
|
||||
} from './prompts';
|
||||
|
||||
export const promptGroupId = {
|
||||
attackDiscovery: 'attackDiscovery',
|
||||
aiAssistant: 'aiAssistant',
|
||||
};
|
||||
|
||||
export const promptDictionary = {
|
||||
systemPrompt: `systemPrompt`,
|
||||
userPrompt: `userPrompt`,
|
||||
attackDiscoveryDefault: `default`,
|
||||
attackDiscoveryRefine: `refine`,
|
||||
attackDiscoveryContinue: `continue`,
|
||||
attackDiscoveryDetailsMarkdown: `detailsMarkdown`,
|
||||
attackDiscoveryEntitySummaryMarkdown: `entitySummaryMarkdown`,
|
||||
attackDiscoveryMitreAttackTactics: `mitreAttackTactics`,
|
||||
attackDiscoverySummaryMarkdown: `summaryMarkdown`,
|
||||
attackDiscoveryGenerationTitle: `generationTitle`,
|
||||
attackDiscoveryGenerationInsights: `generationInsights`,
|
||||
};
|
||||
|
||||
export const localPrompts: Prompt[] = [
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
prompt: {
|
||||
default: DEFAULT_SYSTEM_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
prompt: {
|
||||
default: DEFAULT_SYSTEM_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'bedrock',
|
||||
prompt: {
|
||||
default: BEDROCK_SYSTEM_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'gemini',
|
||||
prompt: {
|
||||
default: GEMINI_SYSTEM_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'openai',
|
||||
model: 'oss',
|
||||
prompt: {
|
||||
default: STRUCTURED_SYSTEM_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.userPrompt,
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
provider: 'gemini',
|
||||
prompt: {
|
||||
default: GEMINI_USER_PROMPT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryDefault,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryRefine,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_REFINE,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryContinue,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_CONTINUE,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryDetailsMarkdown,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryEntitySummaryMarkdown,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryMitreAttackTactics,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoverySummaryMarkdown,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryGenerationTitle,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_TITLE,
|
||||
},
|
||||
},
|
||||
{
|
||||
promptId: promptDictionary.attackDiscoveryGenerationInsights,
|
||||
promptGroupId: promptGroupId.attackDiscovery,
|
||||
prompt: {
|
||||
default: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
|
||||
},
|
||||
},
|
||||
];
|
|
@ -0,0 +1,128 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export const KNOWLEDGE_HISTORY =
|
||||
'If available, use the Knowledge History provided to try and answer the question. If not provided, you can try and query for additional knowledge via the KnowledgeBaseRetrievalTool.';
|
||||
|
||||
export const DEFAULT_SYSTEM_PROMPT = `You are a security analyst and expert in resolving security incidents. Your role is to assist by answering questions about Elastic Security. Do not answer questions unrelated to Elastic Security. ${KNOWLEDGE_HISTORY}`;
|
||||
// system prompt from @afirstenberg
|
||||
const BASE_GEMINI_PROMPT =
|
||||
'You are an assistant that is an expert at using tools and Elastic Security, doing your best to use these tools to answer questions or follow instructions. It is very important to use tools to answer the question or follow the instructions rather than coming up with your own answer. Tool calls are good. Sometimes you may need to make several tool calls to accomplish the task or get an answer to the question that was asked. Use as many tool calls as necessary.';
|
||||
const KB_CATCH =
|
||||
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
|
||||
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
|
||||
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
|
||||
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
|
||||
|
||||
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
|
||||
|
||||
{tools}
|
||||
|
||||
The tool action_input should ALWAYS follow the tool JSON schema args.
|
||||
|
||||
Valid "action" values: "Final Answer" or {tool_names}
|
||||
|
||||
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
|
||||
|
||||
Provide only ONE action per $JSON_BLOB, as shown:
|
||||
|
||||
\`\`\`
|
||||
|
||||
{{
|
||||
|
||||
"action": $TOOL_NAME,
|
||||
|
||||
"action_input": $TOOL_INPUT
|
||||
|
||||
}}
|
||||
|
||||
\`\`\`
|
||||
|
||||
Follow this format:
|
||||
|
||||
Question: input question to answer
|
||||
|
||||
Thought: consider previous and subsequent steps
|
||||
|
||||
Action:
|
||||
|
||||
\`\`\`
|
||||
|
||||
$JSON_BLOB
|
||||
|
||||
\`\`\`
|
||||
|
||||
Observation: action result
|
||||
|
||||
... (repeat Thought/Action/Observation N times)
|
||||
|
||||
Thought: I know what to respond
|
||||
|
||||
Action:
|
||||
|
||||
\`\`\`
|
||||
|
||||
{{
|
||||
|
||||
"action": "Final Answer",
|
||||
|
||||
"action_input": "Final response to human"}}
|
||||
|
||||
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
|
||||
|
||||
export const ATTACK_DISCOVERY_DEFAULT =
|
||||
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).";
|
||||
|
||||
export const ATTACK_DISCOVERY_REFINE = `You previously generated the following insights, but sometimes they represent the same attack.
|
||||
|
||||
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`;
|
||||
|
||||
export const ATTACK_DISCOVERY_CONTINUE = `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
|
||||
1) it MUST conform to the schema above, because it will be checked against the JSON schema
|
||||
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
|
||||
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
|
||||
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
|
||||
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
|
||||
`;
|
||||
|
||||
const SYNTAX = '{{ field.name fieldValue1 fieldValue2 fieldValueN }}';
|
||||
const GOOD_SYNTAX_EXAMPLES =
|
||||
'Examples of CORRECT syntax (includes field names and values): {{ host.name hostNameValue }} {{ user.name userNameValue }} {{ source.ip sourceIpValue }}';
|
||||
|
||||
const BAD_SYNTAX_EXAMPLES =
|
||||
'Examples of INCORRECT syntax (bad, because the field names are not included): {{ hostNameValue }} {{ userNameValue }} {{ sourceIpValue }}';
|
||||
|
||||
const RECONNAISSANCE = 'Reconnaissance';
|
||||
const INITIAL_ACCESS = 'Initial Access';
|
||||
const EXECUTION = 'Execution';
|
||||
const PERSISTENCE = 'Persistence';
|
||||
const PRIVILEGE_ESCALATION = 'Privilege Escalation';
|
||||
const DISCOVERY = 'Discovery';
|
||||
const LATERAL_MOVEMENT = 'Lateral Movement';
|
||||
const COMMAND_AND_CONTROL = 'Command and Control';
|
||||
const EXFILTRATION = 'Exfiltration';
|
||||
|
||||
const MITRE_ATTACK_TACTICS = [
|
||||
RECONNAISSANCE,
|
||||
INITIAL_ACCESS,
|
||||
EXECUTION,
|
||||
PERSISTENCE,
|
||||
PRIVILEGE_ESCALATION,
|
||||
DISCOVERY,
|
||||
LATERAL_MOVEMENT,
|
||||
COMMAND_AND_CONTROL,
|
||||
EXFILTRATION,
|
||||
] as const;
|
||||
export const ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN = `A detailed insight with markdown, where each markdown bullet contains a description of what happened that reads like a story of the attack as it played out and always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`;
|
||||
export const ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN = `A short (no more than a sentence) summary of the insight featuring only the host.name and user.name fields (when they are applicable), using the same ${SYNTAX} syntax`;
|
||||
export const ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS = `An array of MITRE ATT&CK tactic for the insight, using one of the following values: ${MITRE_ATTACK_TACTICS.join(
|
||||
','
|
||||
)}`;
|
||||
export const ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN = `A markdown summary of insight, using the same ${SYNTAX} syntax`;
|
||||
export const ATTACK_DISCOVERY_GENERATION_TITLE =
|
||||
'A short, no more than 7 words, title for the insight, NOT formatted with special syntax or markdown. This must be as brief as possible.';
|
||||
export const ATTACK_DISCOVERY_GENERATION_INSIGHTS = `Insights with markdown that always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`;
|
|
@ -0,0 +1,52 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { SavedObjectsType } from '@kbn/core/server';
|
||||
import { promptSavedObjectType } from '../../../common/constants';
|
||||
|
||||
export const promptSavedObjectMappings: SavedObjectsType['mappings'] = {
|
||||
dynamic: false,
|
||||
properties: {
|
||||
description: {
|
||||
type: 'text',
|
||||
},
|
||||
promptId: {
|
||||
// represent unique prompt
|
||||
type: 'keyword',
|
||||
},
|
||||
promptGroupId: {
|
||||
// represents unique groups of prompts
|
||||
type: 'keyword',
|
||||
},
|
||||
provider: {
|
||||
type: 'keyword',
|
||||
},
|
||||
model: {
|
||||
type: 'keyword',
|
||||
},
|
||||
prompt: {
|
||||
properties: {
|
||||
// English is default
|
||||
default: {
|
||||
type: 'text',
|
||||
},
|
||||
// optionally, add ISO 639 two-letter language code to support more translations
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
export const promptType: SavedObjectsType = {
|
||||
name: promptSavedObjectType,
|
||||
hidden: false,
|
||||
management: {
|
||||
importableAndExportable: true,
|
||||
visibleInManagement: false,
|
||||
},
|
||||
namespaceType: 'agnostic',
|
||||
mappings: promptSavedObjectMappings,
|
||||
};
|
|
@ -0,0 +1,17 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
export interface Prompt {
|
||||
promptId: string;
|
||||
promptGroupId: string;
|
||||
prompt: {
|
||||
default: string;
|
||||
};
|
||||
provider?: string;
|
||||
model?: string;
|
||||
description?: string;
|
||||
}
|
|
@ -10,6 +10,7 @@ import { PluginInitializerContext, CoreStart, Plugin, Logger } from '@kbn/core/s
|
|||
import { AssistantFeatures } from '@kbn/elastic-assistant-common';
|
||||
import { ReplaySubject, type Subject } from 'rxjs';
|
||||
import { MlPluginSetup } from '@kbn/ml-plugin/server';
|
||||
import { initSavedObjects } from './saved_objects';
|
||||
import { events } from './lib/telemetry/event_based_telemetry';
|
||||
import {
|
||||
AssistantTool,
|
||||
|
@ -55,6 +56,8 @@ export class ElasticAssistantPlugin
|
|||
) {
|
||||
this.logger.debug('elasticAssistant: Setup');
|
||||
|
||||
initSavedObjects(core.savedObjects);
|
||||
|
||||
this.assistantService = new AIAssistantService({
|
||||
logger: this.logger.get('service'),
|
||||
ml: plugins.ml,
|
||||
|
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
import type { ActionsClient } from '@kbn/actions-plugin/server';
|
||||
import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
|
||||
import { Logger } from '@kbn/core/server';
|
||||
import { Logger, SavedObjectsClientContract } from '@kbn/core/server';
|
||||
import { ApiConfig, AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
|
||||
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
|
||||
import { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
|
@ -23,6 +23,7 @@ import {
|
|||
import { GraphState } from '../../../../../lib/attack_discovery/graphs/default_attack_discovery_graph/types';
|
||||
import { throwIfErrorCountsExceeded } from '../throw_if_error_counts_exceeded';
|
||||
import { getLlmType } from '../../../../utils';
|
||||
import { getAttackDiscoveryPrompts } from '../../../../../lib/attack_discovery/graphs/default_attack_discovery_graph/nodes/helpers/prompts';
|
||||
|
||||
export const invokeAttackDiscoveryGraph = async ({
|
||||
actionsClient,
|
||||
|
@ -38,6 +39,7 @@ export const invokeAttackDiscoveryGraph = async ({
|
|||
latestReplacements,
|
||||
logger,
|
||||
onNewReplacements,
|
||||
savedObjectsClient,
|
||||
size,
|
||||
start,
|
||||
}: {
|
||||
|
@ -54,6 +56,7 @@ export const invokeAttackDiscoveryGraph = async ({
|
|||
latestReplacements: Replacements;
|
||||
logger: Logger;
|
||||
onNewReplacements: (newReplacements: Replacements) => void;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
start?: string;
|
||||
size: number;
|
||||
}): Promise<{
|
||||
|
@ -89,6 +92,15 @@ export const invokeAttackDiscoveryGraph = async ({
|
|||
throw new Error('LLM is required for attack discoveries');
|
||||
}
|
||||
|
||||
const attackDiscoveryPrompts = await getAttackDiscoveryPrompts({
|
||||
actionsClient,
|
||||
connectorId: apiConfig.connectorId,
|
||||
// if in future oss has different prompt, add it as model here
|
||||
model,
|
||||
provider: llmType,
|
||||
savedObjectsClient,
|
||||
});
|
||||
|
||||
const graph = getDefaultAttackDiscoveryGraph({
|
||||
alertsIndexPattern,
|
||||
anonymizationFields,
|
||||
|
@ -98,6 +110,7 @@ export const invokeAttackDiscoveryGraph = async ({
|
|||
llm,
|
||||
logger,
|
||||
onNewReplacements,
|
||||
prompts: attackDiscoveryPrompts,
|
||||
replacements: latestReplacements,
|
||||
size,
|
||||
start,
|
||||
|
|
|
@ -66,6 +66,7 @@ export const postAttackDiscoveryRoute = (
|
|||
const assistantContext = await context.elasticAssistant;
|
||||
const logger: Logger = assistantContext.logger;
|
||||
const telemetry = assistantContext.telemetry;
|
||||
const savedObjectsClient = assistantContext.savedObjectsClient;
|
||||
|
||||
try {
|
||||
// get the actions plugin start contract from the request context:
|
||||
|
@ -144,6 +145,7 @@ export const postAttackDiscoveryRoute = (
|
|||
latestReplacements,
|
||||
logger,
|
||||
onNewReplacements,
|
||||
savedObjectsClient,
|
||||
size,
|
||||
start,
|
||||
})
|
||||
|
|
|
@ -106,6 +106,7 @@ export const chatCompleteRoute = (
|
|||
const connector = connectors.length > 0 ? connectors[0] : undefined;
|
||||
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
|
||||
const isOssModel = isOpenSourceModel(connector);
|
||||
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
|
||||
|
||||
// replacements
|
||||
const anonymizationFieldsRes =
|
||||
|
@ -221,6 +222,7 @@ export const chatCompleteRoute = (
|
|||
response,
|
||||
telemetry,
|
||||
responseLanguage: request.body.responseLanguage,
|
||||
savedObjectsClient,
|
||||
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
|
||||
});
|
||||
} catch (err) {
|
||||
|
|
|
@ -30,6 +30,14 @@ import {
|
|||
createToolCallingAgent,
|
||||
} from 'langchain/agents';
|
||||
import { omit } from 'lodash/fp';
|
||||
import { promptGroupId } from '../../lib/prompt/local_prompt_object';
|
||||
import { getModelOrOss } from '../../lib/prompt/helpers';
|
||||
import { getAttackDiscoveryPrompts } from '../../lib/attack_discovery/graphs/default_attack_discovery_graph/nodes/helpers/prompts';
|
||||
import {
|
||||
formatPrompt,
|
||||
formatPromptStructured,
|
||||
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
|
||||
import { getPrompt, promptDictionary } from '../../lib/prompt';
|
||||
import { buildResponse } from '../../lib/build_response';
|
||||
import { AssistantDataClients } from '../../lib/langchain/executors/types';
|
||||
import { AssistantToolParams, ElasticAssistantRequestHandlerContext, GetElser } from '../../types';
|
||||
|
@ -42,12 +50,6 @@ import {
|
|||
DefaultAssistantGraph,
|
||||
getDefaultAssistantGraph,
|
||||
} from '../../lib/langchain/graphs/default_assistant_graph/graph';
|
||||
import {
|
||||
bedrockToolCallingAgentPrompt,
|
||||
geminiToolCallingAgentPrompt,
|
||||
openAIFunctionAgentPrompt,
|
||||
structuredChatAgentPrompt,
|
||||
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
|
||||
import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils';
|
||||
import { getGraphsFromNames } from './get_graphs_from_names';
|
||||
|
||||
|
@ -95,6 +97,7 @@ export const postEvaluateRoute = (
|
|||
const actions = ctx.elasticAssistant.actions;
|
||||
const logger = assistantContext.logger.get('evaluate');
|
||||
const abortSignal = getRequestAbortedSignal(request.events.aborted$);
|
||||
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
|
||||
|
||||
// Perform license, authenticated user and evaluation FF checks
|
||||
const checkResponse = performChecks({
|
||||
|
@ -176,6 +179,20 @@ export const postEvaluateRoute = (
|
|||
ids: connectorIds,
|
||||
throwIfSystemAction: false,
|
||||
});
|
||||
const connectorsWithPrompts = await Promise.all(
|
||||
connectors.map(async (connector) => {
|
||||
const prompts = await getAttackDiscoveryPrompts({
|
||||
actionsClient,
|
||||
connectorId: connector.id,
|
||||
connector,
|
||||
savedObjectsClient,
|
||||
});
|
||||
return {
|
||||
...connector,
|
||||
prompts,
|
||||
};
|
||||
})
|
||||
);
|
||||
|
||||
// Fetch any tools registered to the security assistant
|
||||
const assistantTools = assistantContext.getRegisteredTools(DEFAULT_PLUGIN_NAME);
|
||||
|
@ -190,7 +207,7 @@ export const postEvaluateRoute = (
|
|||
actionsClient,
|
||||
alertsIndexPattern,
|
||||
attackDiscoveryGraphs,
|
||||
connectors,
|
||||
connectors: connectorsWithPrompts,
|
||||
connectorTimeout: CONNECTOR_TIMEOUT,
|
||||
datasetName,
|
||||
esClient,
|
||||
|
@ -217,6 +234,7 @@ export const postEvaluateRoute = (
|
|||
graph: DefaultAssistantGraph;
|
||||
llmType: string | undefined;
|
||||
isOssModel: boolean | undefined;
|
||||
connectorId: string;
|
||||
}> = await Promise.all(
|
||||
connectors.map(async (connector) => {
|
||||
const llmType = getLlmType(connector.actionTypeId);
|
||||
|
@ -293,31 +311,40 @@ export const postEvaluateRoute = (
|
|||
(tool) => tool.getTool(assistantToolParams) ?? []
|
||||
);
|
||||
|
||||
const defaultSystemPrompt = await getPrompt({
|
||||
actionsClient,
|
||||
connector,
|
||||
connectorId: connector.id,
|
||||
model: getModelOrOss(llmType, isOssModel),
|
||||
promptGroupId: promptGroupId.aiAssistant,
|
||||
promptId: promptDictionary.systemPrompt,
|
||||
provider: llmType,
|
||||
savedObjectsClient,
|
||||
});
|
||||
|
||||
const agentRunnable = isOpenAI
|
||||
? await createOpenAIFunctionsAgent({
|
||||
llm,
|
||||
tools,
|
||||
prompt: openAIFunctionAgentPrompt,
|
||||
prompt: formatPrompt(defaultSystemPrompt),
|
||||
streamRunnable: false,
|
||||
})
|
||||
: llmType && ['bedrock', 'gemini'].includes(llmType)
|
||||
? createToolCallingAgent({
|
||||
llm,
|
||||
tools,
|
||||
prompt:
|
||||
llmType === 'bedrock'
|
||||
? bedrockToolCallingAgentPrompt
|
||||
: geminiToolCallingAgentPrompt,
|
||||
prompt: formatPrompt(defaultSystemPrompt),
|
||||
streamRunnable: false,
|
||||
})
|
||||
: await createStructuredChatAgent({
|
||||
llm,
|
||||
tools,
|
||||
prompt: structuredChatAgentPrompt,
|
||||
prompt: formatPromptStructured(defaultSystemPrompt),
|
||||
streamRunnable: false,
|
||||
});
|
||||
|
||||
return {
|
||||
connectorId: connector.id,
|
||||
name: `${runName} - ${connector.name}`,
|
||||
llmType,
|
||||
isOssModel,
|
||||
|
@ -326,6 +353,8 @@ export const postEvaluateRoute = (
|
|||
dataClients,
|
||||
createLlmInstance,
|
||||
logger,
|
||||
actionsClient,
|
||||
savedObjectsClient,
|
||||
tools,
|
||||
replacements: {},
|
||||
}),
|
||||
|
@ -334,7 +363,7 @@ export const postEvaluateRoute = (
|
|||
);
|
||||
|
||||
// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
|
||||
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => {
|
||||
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel, connectorId }) => {
|
||||
// Wrapper function for invoking the graph (to parse different input/output formats)
|
||||
const predict = async (input: { input: string }) => {
|
||||
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
|
||||
|
@ -342,6 +371,7 @@ export const postEvaluateRoute = (
|
|||
const r = await graph.invoke(
|
||||
{
|
||||
input: input.input,
|
||||
connectorId,
|
||||
conversationId: undefined,
|
||||
responseLanguage: 'English',
|
||||
llmType,
|
||||
|
|
|
@ -12,6 +12,7 @@ import {
|
|||
KibanaRequest,
|
||||
KibanaResponseFactory,
|
||||
Logger,
|
||||
SavedObjectsClientContract,
|
||||
} from '@kbn/core/server';
|
||||
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
|
||||
|
||||
|
@ -235,6 +236,7 @@ export interface LangChainExecuteParams {
|
|||
getElser: GetElser;
|
||||
response: KibanaResponseFactory;
|
||||
responseLanguage?: string;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
systemPrompt?: string;
|
||||
}
|
||||
export const langChainExecute = async ({
|
||||
|
@ -258,6 +260,7 @@ export const langChainExecute = async ({
|
|||
response,
|
||||
responseLanguage,
|
||||
isStream = true,
|
||||
savedObjectsClient,
|
||||
systemPrompt,
|
||||
}: LangChainExecuteParams) => {
|
||||
// Fetch any tools registered by the request's originating plugin
|
||||
|
@ -316,6 +319,7 @@ export const langChainExecute = async ({
|
|||
request,
|
||||
replacements,
|
||||
responseLanguage,
|
||||
savedObjectsClient,
|
||||
size: request.body.size,
|
||||
systemPrompt,
|
||||
telemetry,
|
||||
|
|
|
@ -99,6 +99,7 @@ export const postActionsConnectorExecuteRoute = (
|
|||
// get the actions plugin start contract from the request context:
|
||||
const actions = ctx.elasticAssistant.actions;
|
||||
const inference = ctx.elasticAssistant.inference;
|
||||
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
|
||||
const productDocsAvailable =
|
||||
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
|
||||
const actionsClient = await actions.getActionsClientWithRequest(request);
|
||||
|
@ -153,6 +154,7 @@ export const postActionsConnectorExecuteRoute = (
|
|||
request,
|
||||
response,
|
||||
telemetry,
|
||||
savedObjectsClient,
|
||||
systemPrompt,
|
||||
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
|
||||
});
|
||||
|
|
|
@ -80,7 +80,7 @@ export class RequestContextFactory implements IRequestContextFactory {
|
|||
},
|
||||
llmTasks: startPlugins.llmTasks,
|
||||
inference: startPlugins.inference,
|
||||
|
||||
savedObjectsClient: coreStart.savedObjects.getScopedClient(request),
|
||||
telemetry: core.analytics,
|
||||
|
||||
// Note: modelIdOverride is used here to enable setting up the KB using a different ELSER model, which
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { CoreSetup } from '@kbn/core/server';
|
||||
import { promptType } from './lib/prompt/saved_object_mappings';
|
||||
|
||||
export const initSavedObjects = (savedObjects: CoreSetup['savedObjects']) => {
|
||||
try {
|
||||
savedObjects.registerType(promptType);
|
||||
} catch (e) {
|
||||
// implementation intends to fall back to reasonable defaults when the saved objects are unavailable
|
||||
// do not block the plugin from starting
|
||||
}
|
||||
};
|
|
@ -9,23 +9,24 @@ import type {
|
|||
PluginSetupContract as ActionsPluginSetup,
|
||||
PluginStartContract as ActionsPluginStart,
|
||||
} from '@kbn/actions-plugin/server';
|
||||
import type {
|
||||
import {
|
||||
AuthenticatedUser,
|
||||
CoreRequestHandlerContext,
|
||||
CoreSetup,
|
||||
AnalyticsServiceSetup,
|
||||
CustomRequestHandlerContext,
|
||||
ElasticsearchClient,
|
||||
IRouter,
|
||||
KibanaRequest,
|
||||
Logger,
|
||||
AuditLogger,
|
||||
SavedObjectsClientContract,
|
||||
} from '@kbn/core/server';
|
||||
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
|
||||
import { type MlPluginSetup } from '@kbn/ml-plugin/server';
|
||||
import { DynamicStructuredTool, Tool } from '@langchain/core/tools';
|
||||
import { SpacesPluginSetup, SpacesPluginStart } from '@kbn/spaces-plugin/server';
|
||||
import { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server';
|
||||
import { ElasticsearchClient } from '@kbn/core/server';
|
||||
import {
|
||||
AttackDiscoveryPostRequestBody,
|
||||
DefendInsightsPostRequestBody,
|
||||
|
@ -140,6 +141,7 @@ export interface ElasticAssistantApiRequestHandlerContext {
|
|||
getAIAssistantAnonymizationFieldsDataClient: () => Promise<AIAssistantDataClient | null>;
|
||||
llmTasks: LlmTasksPluginStart;
|
||||
inference: InferenceServerStart;
|
||||
savedObjectsClient: SavedObjectsClientContract;
|
||||
telemetry: AnalyticsServiceSetup;
|
||||
}
|
||||
/**
|
||||
|
|
|
@ -53,6 +53,8 @@
|
|||
"@kbn/core-analytics-server",
|
||||
"@kbn/llm-tasks-plugin",
|
||||
"@kbn/product-doc-base-plugin",
|
||||
"@kbn/core-saved-objects-api-server-mocks",
|
||||
"@kbn/inference-common"
|
||||
],
|
||||
"exclude": [
|
||||
"target/**/*",
|
||||
|
|
|
@ -171,6 +171,7 @@ export const generateFleetPackageInfo = (): PackageInfo => {
|
|||
ml_module: [],
|
||||
security_rule: [],
|
||||
tag: [],
|
||||
security_ai_prompt: [],
|
||||
osquery_pack_asset: [],
|
||||
osquery_saved_query: [],
|
||||
},
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue