[Security AI] Add Kibana Support for Security AI Prompts Integration (#207138)

Steph Milovic 2025-01-28 15:35:39 -07:00 committed by GitHub
parent b998946003
commit 7af5a8338b
86 changed files with 1920 additions and 398 deletions

View file

@ -914,6 +914,15 @@
"username"
],
"search-telemetry": [],
"security-ai-prompt": [
"description",
"model",
"prompt",
"prompt.default",
"promptGroupId",
"promptId",
"provider"
],
"security-rule": [
"rule_id",
"version"

View file

@ -3023,6 +3023,33 @@
"dynamic": false,
"properties": {}
},
"security-ai-prompt": {
"dynamic": false,
"properties": {
"description": {
"type": "text"
},
"model": {
"type": "keyword"
},
"prompt": {
"properties": {
"default": {
"type": "text"
}
}
},
"promptGroupId": {
"type": "keyword"
},
"promptId": {
"type": "keyword"
},
"provider": {
"type": "keyword"
}
}
},
"security-rule": {
"dynamic": false,
"properties": {

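The new `security-ai-prompt` saved object stores prompt overrides keyed by `promptId`, `promptGroupId`, `provider`, and `model`, with the text itself under `prompt.default`. A minimal write sketch, not taken from this PR, assuming the standard `SavedObjectsClientContract` API and illustrative attribute values:

```ts
import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';

// Create one prompt override matching the mappings above. The promptId and
// promptGroupId values mirror the prompt dictionary used later in this PR;
// the prompt text is a placeholder.
async function createPromptOverride(soClient: SavedObjectsClientContract) {
  return soClient.create('security-ai-prompt', {
    description: 'Override for the Attack discovery default prompt',
    promptId: 'attackDiscoveryDefault',
    promptGroupId: 'attackDiscovery',
    provider: 'bedrock',
    model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
    prompt: { default: 'You are a cyber security analyst...' },
  });
}
```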
View file

@ -11,4 +11,4 @@ export { registerCoreObjectTypes } from './registration';
// set the minimum number of registered saved objects to ensure no object types are removed after 8.8
// declared explicitly in the internal implementation to prevent unintended changes.
export const SAVED_OBJECT_TYPES_COUNT = 127 as const;
export const SAVED_OBJECT_TYPES_COUNT = 128 as const;

View file

@ -157,6 +157,7 @@ describe('checking migration metadata changes on all registered SO types', () =>
"search": "0aa6eefb37edd3145be340a8b67779c2ca578b22",
"search-session": "b2fcd840e12a45039ada50b1355faeafa39876d1",
"search-telemetry": "b568601618744720b5662946d3103e3fb75fe8ee",
"security-ai-prompt": "cc8ee5aaa9d001e89c131bbd5af6bc80bc271046",
"security-rule": "07abb4d7e707d91675ec0495c73816394c7b521f",
"security-solution-signals-migration": "9d99715fe5246f19de2273ba77debd2446c36bb1",
"siem-detection-engine-rule-actions": "54f08e23887b20da7c805fab7c60bc67c428aff9",

View file

@ -123,6 +123,7 @@ const previouslyRegisteredTypes = [
'search',
'search-session',
'search-telemetry',
'security-ai-prompt',
'security-rule',
'security-solution-signals-migration',
'risk-engine-configuration',

View file

@ -96,6 +96,7 @@ export {
isInferenceRequestError,
isInferenceRequestAbortedError,
} from './src/errors';
export { elasticModelDictionary } from './src/const';
export { truncateList } from './src/truncate_list';
export {

View file

@ -0,0 +1,15 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ElasticModelDictionary } from './types';
export const elasticModelDictionary: ElasticModelDictionary = {
'rainbow-sprinkles': {
provider: 'bedrock',
model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
},
};
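This dictionary resolves the Elastic-managed model alias to a concrete inference provider and model. A small usage sketch; the import path is assumed from the package index edited above:

```ts
import { elasticModelDictionary } from '@kbn/inference-common'; // path assumed

// Resolve an alias to its provider/model pair, guarding against unknown keys.
export const resolveElasticModel = (alias: string) =>
  elasticModelDictionary[alias] ?? { provider: 'unknown', model: 'unknown' };

// resolveElasticModel('rainbow-sprinkles')
//   -> { provider: 'bedrock', model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0' }
```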

View file

@ -0,0 +1,13 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export interface ElasticModelDictionary {
[key: string]: {
provider: string;
model: string;
};
}

View file

@ -251,6 +251,7 @@ export const item: GetInfoResponse['item'] = {
index_pattern: [],
lens: [],
map: [],
security_ai_prompt: [],
security_rule: [],
csp_rule_template: [],
tag: [],

View file

@ -106,6 +106,7 @@ export const item: GetInfoResponse['item'] = {
ml_module: [],
osquery_pack_asset: [],
osquery_saved_query: [],
security_ai_prompt: [],
security_rule: [],
csp_rule_template: [],
tag: [],

View file

@ -33,6 +33,7 @@ describe('Fleet - packageToPackagePolicy', () => {
map: [],
lens: [],
ml_module: [],
security_ai_prompt: [],
security_rule: [],
tag: [],
osquery_pack_asset: [],

View file

@ -59,6 +59,7 @@ export enum KibanaAssetType {
indexPattern = 'index_pattern',
map = 'map',
mlModule = 'ml_module',
securityAIPrompt = 'security_ai_prompt',
securityRule = 'security_rule',
cloudSecurityPostureRuleTemplate = 'csp_rule_template',
osqueryPackAsset = 'osquery_pack_asset',
@ -77,6 +78,7 @@ export enum KibanaSavedObjectType {
indexPattern = 'index-pattern',
map = 'map',
mlModule = 'ml-module',
securityAIPrompt = 'security-ai-prompt',
securityRule = 'security-rule',
cloudSecurityPostureRuleTemplate = 'csp-rule-template',
osqueryPackAsset = 'osquery-pack-asset',

View file

@ -47,6 +47,12 @@ export const AssetTitleMap: Record<
map: i18n.translate('xpack.fleet.epm.assetTitles.maps', {
defaultMessage: 'Maps',
}),
'security-ai-prompt': i18n.translate('xpack.fleet.epm.assetTitles.securityAIPrompt', {
defaultMessage: 'Security AI prompt',
}),
security_ai_prompt: i18n.translate('xpack.fleet.epm.assetTitles.securityAIPrompt', {
defaultMessage: 'Security AI prompt',
}),
'security-rule': i18n.translate('xpack.fleet.epm.assetTitles.securityRules', {
defaultMessage: 'Security rules',
}),

View file

@ -193,6 +193,7 @@ describe('schema validation', () => {
map: [],
index_pattern: [],
ml_module: [],
security_ai_prompt: [],
security_rule: [],
tag: [],
csp_rule_template: [],

View file

@ -38,6 +38,7 @@ packageInfoCache.set('test_package-0.0.0', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],
@ -122,6 +123,7 @@ packageInfoCache.set('osquery_manager-0.3.0', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],
@ -172,6 +174,7 @@ packageInfoCache.set('profiler_symbolizer-8.8.0-preview', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],
@ -222,6 +225,7 @@ packageInfoCache.set('profiler_collector-8.9.0-preview', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],
@ -264,6 +268,7 @@ packageInfoCache.set('apm-8.9.0-preview', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],
@ -306,6 +311,7 @@ packageInfoCache.set('elastic_connectors-1.0.0', {
index_pattern: [],
map: [],
lens: [],
security_ai_prompt: [],
security_rule: [],
ml_module: [],
tag: [],

View file

@ -33,9 +33,10 @@ import { deleteKibanaSavedObjectsAssets } from '../../packages/remove';
import { FleetError, KibanaSOReferenceError } from '../../../../errors';
import { withPackageSpan } from '../../packages/utils';
import { appContextService } from '../../..';
import { tagKibanaAssets } from './tag_assets';
import { getSpaceAwareSaveobjectsClients } from './saved_objects';
import { appContextService } from '../../..';
const MAX_ASSETS_TO_INSTALL_IN_PARALLEL = 1000;
@ -70,6 +71,7 @@ export const KibanaSavedObjectTypeMapping: Record<KibanaAssetType, KibanaSavedOb
[KibanaAssetType.visualization]: KibanaSavedObjectType.visualization,
[KibanaAssetType.lens]: KibanaSavedObjectType.lens,
[KibanaAssetType.mlModule]: KibanaSavedObjectType.mlModule,
[KibanaAssetType.securityAIPrompt]: KibanaSavedObjectType.securityAIPrompt,
[KibanaAssetType.securityRule]: KibanaSavedObjectType.securityRule,
[KibanaAssetType.cloudSecurityPostureRuleTemplate]:
KibanaSavedObjectType.cloudSecurityPostureRuleTemplate,
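During install, Fleet uses this mapping to resolve each asset directory's `KibanaAssetType` to the saved object type it should create; a one-line illustration of the new entry:

```ts
// Illustrative only: the new entry resolves the asset type to its SO type.
const soType = KibanaSavedObjectTypeMapping[KibanaAssetType.securityAIPrompt];
// soType === KibanaSavedObjectType.securityAIPrompt, i.e. 'security-ai-prompt'
```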

View file

@ -31,3 +31,6 @@ export const CAPABILITIES = `${BASE_PATH}/capabilities`;
Licensing requirements
*/
export const MINIMUM_AI_ASSISTANT_LICENSE = 'enterprise' as const;
// Saved Objects
export const promptSavedObjectType = 'security-ai-prompt';
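A lookup sketch for the new constant, assuming the standard `SavedObjectsClientContract` API; the import path and KQL filter are illustrative:

```ts
import type { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { promptSavedObjectType } from './constants'; // path assumed

// Find stored prompt overrides for one prompt group.
export const findAttackDiscoveryPromptOverrides = (soClient: SavedObjectsClientContract) =>
  soClient.find({
    type: promptSavedObjectType,
    filter: `${promptSavedObjectType}.attributes.promptGroupId: attackDiscovery`,
  });
```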

View file

@ -18,6 +18,19 @@ import type { Logger } from '@kbn/logging';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { FakeLLM } from '@langchain/core/utils/testing';
import { createOpenAIFunctionsAgent } from 'langchain/agents';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { savedObjectsClientMock } from '@kbn/core/server/mocks';
import {
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_TITLE,
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
} from '../server/lib/prompt/prompts';
import { getDefaultAssistantGraph } from '../server/lib/langchain/graphs/default_assistant_graph/graph';
import { getDefaultAttackDiscoveryGraph } from '../server/lib/attack_discovery/graphs/default_attack_discovery_graph';
@ -49,11 +62,13 @@ async function getAssistantGraph(logger: Logger): Promise<Drawable> {
streamRunnable: false,
});
const graph = getDefaultAssistantGraph({
actionsClient: actionsClientMock.create(),
agentRunnable,
logger,
createLlmInstance,
tools: [],
replacements: {},
savedObjectsClient: savedObjectsClientMock.create(),
});
return graph.getGraph();
}
@ -67,6 +82,17 @@ async function getAttackDiscoveryGraph(logger: Logger): Promise<Drawable> {
llm: mockLlm as unknown as ActionsClientLlm,
logger,
replacements: {},
prompts: {
default: ATTACK_DISCOVERY_DEFAULT,
refine: ATTACK_DISCOVERY_REFINE,
continue: ATTACK_DISCOVERY_CONTINUE,
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
title: ATTACK_DISCOVERY_GENERATION_TITLE,
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
},
size: 20,
});

View file

@ -54,8 +54,8 @@ export const createMockClients = () => {
getCurrentUser: jest.fn(),
inference: jest.fn(),
llmTasks: jest.fn(),
savedObjectsClient: core.savedObjects.client,
},
savedObjectsClient: core.savedObjects.client,
licensing: {
...licensingMock.createRequestHandlerContext({ license }),
@ -148,6 +148,7 @@ const createElasticAssistantRequestContextMock = (
inference: { getClient: jest.fn() },
llmTasks: { retrieveDocumentationAvailable: jest.fn(), retrieveDocumentation: jest.fn() },
core: clients.core,
savedObjectsClient: clients.elasticAssistant.savedObjectsClient,
telemetry: clients.elasticAssistant.telemetry,
};
};

View file

@ -4,7 +4,6 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks';
@ -46,7 +45,6 @@ jest.mock('./helpers/get_evaluator_llm', () => {
getEvaluatorLlm: jest.fn().mockResolvedValue(mockLlm),
};
});
const actionsClient = {
get: jest.fn(),
} as unknown as ActionsClient;
@ -61,7 +59,22 @@ const logger = loggerMock.create();
const mockEsClient = elasticsearchServiceMock.createElasticsearchClient();
const runName = 'test-run-name';
const connectors = [mockExperimentConnector];
const connectors = [
{
...mockExperimentConnector,
prompts: {
default: '',
refine: '',
continue: '',
detailsMarkdown: '',
entitySummaryMarkdown: '',
mitreAttackTactics: '',
summaryMarkdown: '',
title: '',
insights: '',
},
},
];
const projectName = 'test-lang-smith-project';

View file

@ -16,12 +16,16 @@ import { getLangSmithTracer } from '@kbn/langchain/server/tracers/langsmith';
import { asyncForEach } from '@kbn/std';
import { PublicMethodsOf } from '@kbn/utility-types';
import { CombinedPrompts } from '../graphs/default_attack_discovery_graph/nodes/helpers/prompts';
import { DEFAULT_EVAL_ANONYMIZATION_FIELDS } from './constants';
import { AttackDiscoveryGraphMetadata } from '../../langchain/graphs';
import { DefaultAttackDiscoveryGraph } from '../graphs/default_attack_discovery_graph';
import { getLlmType } from '../../../routes/utils';
import { runEvaluations } from './run_evaluations';
interface ConnectorWithPrompts extends Connector {
prompts: CombinedPrompts;
}
export const evaluateAttackDiscovery = async ({
actionsClient,
attackDiscoveryGraphs,
@ -43,7 +47,7 @@ export const evaluateAttackDiscovery = async ({
attackDiscoveryGraphs: AttackDiscoveryGraphMetadata[];
alertsIndexPattern: string;
anonymizationFields?: AnonymizationFieldResponse[];
connectors: Connector[];
connectors: ConnectorWithPrompts[];
connectorTimeout: number;
datasetName: string;
esClient: ElasticsearchClient;
@ -96,6 +100,7 @@ export const evaluateAttackDiscovery = async ({
esClient,
llm,
logger,
prompts: connector.prompts,
size,
});

View file

@ -35,6 +35,7 @@ const graphState: GraphState = {
],
combinedGenerations: 'generations',
combinedRefinements: 'refinements',
continuePrompt: 'continue',
errors: [],
generationAttempts: 0,
generations: [],

View file

@ -35,6 +35,7 @@ const graphState: GraphState = {
],
combinedGenerations: '',
combinedRefinements: '',
continuePrompt: 'continue',
errors: [],
generationAttempts: 0,
generations: [],

View file

@ -23,6 +23,7 @@ const initialGraphState: GraphState = {
anonymizedAlerts: [...mockAnonymizedAlerts],
combinedGenerations: 'generations',
combinedRefinements: '',
continuePrompt: 'continue',
errors: [],
generationAttempts: 2,
generations: ['gen', 'erations'],

View file

@ -19,6 +19,7 @@ const initialGraphState: GraphState = {
anonymizedAlerts: [],
combinedGenerations: '',
combinedRefinements: '',
continuePrompt: 'continue',
errors: [],
generationAttempts: 0,
generations: [],

View file

@ -5,13 +5,14 @@
* 2.0.
*/
import type { ElasticsearchClient, Logger } from '@kbn/core/server';
import { ElasticsearchClient, Logger } from '@kbn/core/server';
import { Replacements } from '@kbn/elastic-assistant-common';
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
import type { ActionsClientLlm } from '@kbn/langchain/server';
import type { CompiledStateGraph } from '@langchain/langgraph';
import { END, START, StateGraph } from '@langchain/langgraph';
import { CombinedPrompts } from './nodes/helpers/prompts';
import { NodeType } from './constants';
import { getGenerateOrEndEdge } from './edges/generate_or_end';
import { getGenerateOrRefineOrEndEdge } from './edges/generate_or_refine_or_end';
@ -32,6 +33,7 @@ export interface GetDefaultAttackDiscoveryGraphParams {
llm: ActionsClientLlm;
logger?: Logger;
onNewReplacements?: (replacements: Replacements) => void;
prompts: CombinedPrompts;
replacements?: Replacements;
size: number;
start?: string;
@ -55,6 +57,7 @@ export const getDefaultAttackDiscoveryGraph = ({
llm,
logger,
onNewReplacements,
prompts,
replacements,
size,
start,
@ -64,7 +67,7 @@ export const getDefaultAttackDiscoveryGraph = ({
'generate' | 'refine' | 'retrieve_anonymized_alerts' | '__start__'
> => {
try {
const graphState = getDefaultGraphState({ end, filter, start });
const graphState = getDefaultGraphState({ end, filter, prompts, start });
// get nodes:
const retrieveAnonymizedAlertsNode = getRetrieveAnonymizedAlertsNode({
@ -80,11 +83,13 @@ export const getDefaultAttackDiscoveryGraph = ({
const generateNode = getGenerateNode({
llm,
logger,
prompts,
});
const refineNode = getRefineNode({
llm,
logger,
prompts,
});
// get edges:

View file

@ -31,6 +31,7 @@ const graphState: GraphState = {
],
combinedGenerations: 'combinedGenerations',
combinedRefinements: '',
continuePrompt: 'continue',
errors: [],
generationAttempts: 2,
generations: ['combined', 'Generations'],

View file

@ -6,13 +6,13 @@
*/
import { getAlertsContextPrompt } from '.';
import { getDefaultAttackDiscoveryPrompt } from '../../../helpers/get_default_attack_discovery_prompt';
import { ATTACK_DISCOVERY_DEFAULT } from '../../../../../../../prompt/prompts';
describe('getAlertsContextPrompt', () => {
it('generates the correct prompt', () => {
const anonymizedAlerts = ['Alert 1', 'Alert 2', 'Alert 3'];
const expected = `${getDefaultAttackDiscoveryPrompt()}
const expected = `${ATTACK_DISCOVERY_DEFAULT}
Use context from the following alerts to provide insights:
@ -27,7 +27,7 @@ Alert 3
const prompt = getAlertsContextPrompt({
anonymizedAlerts,
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
});
expect(prompt).toEqual(expected);

View file

@ -16,6 +16,7 @@ const graphState: GraphState = {
anonymizedAlerts: mockAnonymizedAlerts, // <-- mockAnonymizedAlerts is an array of objects with a pageContent property
combinedGenerations: 'combinedGenerations',
combinedRefinements: '',
continuePrompt: 'continue',
errors: [],
generationAttempts: 2,
generations: ['combined', 'Generations'],

View file

@ -16,13 +16,16 @@ import {
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
import { getAnonymizedAlertsFromState } from './helpers/get_anonymized_alerts_from_state';
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
import { GraphState } from '../../types';
import {
getParsedAttackDiscoveriesMock,
getRawAttackDiscoveriesMock,
} from '../../../../../../__mocks__/raw_attack_discoveries';
import {
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
} from '../../../../../prompt/prompts';
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
@ -49,10 +52,11 @@ let mockLlm: ActionsClientLlm;
const initialGraphState: GraphState = {
attackDiscoveries: null,
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
anonymizedAlerts: [...mockAnonymizedAlerts],
combinedGenerations: '',
combinedRefinements: '',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
errors: [],
generationAttempts: 0,
generations: [],
@ -61,13 +65,25 @@ const initialGraphState: GraphState = {
maxHallucinationFailures: 5,
maxRepeatedGenerations: 3,
refinements: [],
refinePrompt: getDefaultRefinePrompt(),
refinePrompt: ATTACK_DISCOVERY_REFINE,
replacements: {
...mockAnonymizedAlertsReplacements,
},
unrefinedResults: null,
};
const prompts = {
default: '',
refine: '',
continue: '',
detailsMarkdown: '',
entitySummaryMarkdown: '',
mitreAttackTactics: '',
summaryMarkdown: '',
title: '',
insights: '',
};
describe('getGenerateNode', () => {
beforeEach(() => {
jest.clearAllMocks();
@ -88,17 +104,20 @@ describe('getGenerateNode', () => {
const generateNode = getGenerateNode({
llm: mockLlm,
logger: mockLogger,
prompts,
});
expect(typeof generateNode).toBe('function');
});
it('invokes the chain with the expected alerts from state and formatting instructions', async () => {
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlm, prompts }).chain
.invoke as jest.Mock;
const generateNode = getGenerateNode({
llm: mockLlm,
logger: mockLogger,
prompts,
});
await generateNode(initialGraphState);
@ -121,7 +140,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(response);
@ -129,6 +148,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const generateNode = getGenerateNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const state = await generateNode(initialGraphState);
@ -151,14 +171,15 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const mockLlmWithHallucination = new FakeLLM({
response: hallucinatedResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithHallucination, prompts })
.chain.invoke as jest.Mock;
mockInvoke.mockResolvedValue(hallucinatedResponse);
const generateNode = getGenerateNode({
llm: mockLlmWithHallucination,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -185,14 +206,17 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const mockLlmWithRepeatedGenerations = new FakeLLM({
response: repeatedResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({
llm: mockLlmWithRepeatedGenerations,
prompts,
}).chain.invoke as jest.Mock;
mockInvoke.mockResolvedValue(repeatedResponse);
const generateNode = getGenerateNode({
llm: mockLlmWithRepeatedGenerations,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -218,7 +242,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const mockLlmWithResponse = new FakeLLM({
response,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(response);
@ -226,6 +250,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const generateNode = getGenerateNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -257,7 +282,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const mockLlmWithResponse = new FakeLLM({
response: secondResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(secondResponse);
@ -265,6 +290,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const generateNode = getGenerateNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -296,7 +322,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const mockLlmWithResponse = new FakeLLM({
response: secondResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(secondResponse);
@ -304,6 +330,7 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
const generateNode = getGenerateNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {

View file

@ -8,6 +8,7 @@
import type { ActionsClientLlm } from '@kbn/langchain/server';
import type { Logger } from '@kbn/core/server';
import { GenerationPrompts } from '../helpers/prompts';
import { discardPreviousGenerations } from './helpers/discard_previous_generations';
import { extractJson } from '../helpers/extract_json';
import { getAnonymizedAlertsFromState } from './helpers/get_anonymized_alerts_from_state';
@ -23,9 +24,11 @@ import type { GraphState } from '../../types';
export const getGenerateNode = ({
llm,
logger,
prompts,
}: {
llm: ActionsClientLlm;
logger?: Logger;
prompts: GenerationPrompts;
}): ((state: GraphState) => Promise<GraphState>) => {
const generate = async (state: GraphState): Promise<GraphState> => {
logger?.debug(() => `---GENERATE---`);
@ -34,6 +37,7 @@ export const getGenerateNode = ({
const {
attackDiscoveryPrompt,
continuePrompt,
combinedGenerations,
generationAttempts,
generations,
@ -50,9 +54,13 @@ export const getGenerateNode = ({
anonymizedAlerts,
attackDiscoveryPrompt,
combinedMaybePartialResults: combinedGenerations,
continuePrompt,
});
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions(llm);
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions({
llm,
prompts,
});
logger?.debug(
() => `generate node is invoking the chain (${llmType}), attempt ${generationAttempts}`
@ -112,6 +120,7 @@ export const getGenerateNode = ({
llmType,
logger,
nodeName: 'generate',
prompts,
});
// use the unrefined results if we already reached the max number of retries:

View file

@ -13,72 +13,20 @@
*/
import { z } from '@kbn/zod';
import { GenerationPrompts } from '../../helpers/prompts';
export const SYNTAX = '{{ field.name fieldValue1 fieldValue2 fieldValueN }}';
const GOOD_SYNTAX_EXAMPLES =
'Examples of CORRECT syntax (includes field names and values): {{ host.name hostNameValue }} {{ user.name userNameValue }} {{ source.ip sourceIpValue }}';
const BAD_SYNTAX_EXAMPLES =
'Examples of INCORRECT syntax (bad, because the field names are not included): {{ hostNameValue }} {{ userNameValue }} {{ sourceIpValue }}';
const RECONNAISSANCE = 'Reconnaissance';
const INITIAL_ACCESS = 'Initial Access';
const EXECUTION = 'Execution';
const PERSISTENCE = 'Persistence';
const PRIVILEGE_ESCALATION = 'Privilege Escalation';
const DISCOVERY = 'Discovery';
const LATERAL_MOVEMENT = 'Lateral Movement';
const COMMAND_AND_CONTROL = 'Command and Control';
const EXFILTRATION = 'Exfiltration';
const MITRE_ATTACK_TACTICS = [
RECONNAISSANCE,
INITIAL_ACCESS,
EXECUTION,
PERSISTENCE,
PRIVILEGE_ESCALATION,
DISCOVERY,
LATERAL_MOVEMENT,
COMMAND_AND_CONTROL,
EXFILTRATION,
] as const;
export const AttackDiscoveriesGenerationSchema = z.object({
insights: z
.array(
z.object({
alertIds: z.string().array().describe(`The alert IDs that the insight is based on.`),
detailsMarkdown: z
.string()
.describe(
`A detailed insight with markdown, where each markdown bullet contains a description of what happened that reads like a story of the attack as it played out and always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`
),
entitySummaryMarkdown: z
.string()
.optional()
.describe(
`A short (no more than a sentence) summary of the insight featuring only the host.name and user.name fields (when they are applicable), using the same ${SYNTAX} syntax`
),
mitreAttackTactics: z
.string()
.array()
.optional()
.describe(
`An array of MITRE ATT&CK tactic for the insight, using one of the following values: ${MITRE_ATTACK_TACTICS.join(
','
)}`
),
summaryMarkdown: z
.string()
.describe(`A markdown summary of insight, using the same ${SYNTAX} syntax`),
title: z
.string()
.describe(
'A short, no more than 7 words, title for the insight, NOT formatted with special syntax or markdown. This must be as brief as possible.'
),
})
)
.describe(
`Insights with markdown that always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`
),
});
export const getAttackDiscoveriesGenerationSchema = (prompts: GenerationPrompts) =>
z.object({
insights: z
.array(
z.object({
alertIds: z.string().array().describe(`The alert IDs that the insight is based on.`),
detailsMarkdown: z.string().describe(prompts.detailsMarkdown),
entitySummaryMarkdown: z.string().optional().describe(prompts.entitySummaryMarkdown),
mitreAttackTactics: z.string().array().optional().describe(prompts.mitreAttackTactics),
summaryMarkdown: z.string().describe(prompts.summaryMarkdown),
title: z.string().describe(prompts.title),
})
)
.describe(prompts.insights),
});
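The static schema becomes a factory so the `.describe()` text, which feeds the LLM's format instructions, can come from resolved prompts and vary per provider or model. A usage sketch with placeholder prompt values:

```ts
// Build the schema from resolved prompts, then validate a parsed LLM response.
const schema = getAttackDiscoveriesGenerationSchema({
  detailsMarkdown: 'A detailed insight with markdown...',
  entitySummaryMarkdown: 'A short entity summary...',
  mitreAttackTactics: 'An array of MITRE ATT&CK tactics...',
  summaryMarkdown: 'A markdown summary...',
  title: 'A short title...',
  insights: 'Insights with markdown...',
});

schema.parse({ insights: [] }); // throws if the output does not match the schema
```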

View file

@ -9,6 +9,23 @@ import { FakeLLM } from '@langchain/core/utils/testing';
import type { ActionsClientLlm } from '@kbn/langchain/server';
import { getChainWithFormatInstructions } from '.';
import {
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_TITLE,
} from '../../../../../../prompt/prompts';
const prompts = {
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
title: ATTACK_DISCOVERY_GENERATION_TITLE,
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
};
describe('getChainWithFormatInstructions', () => {
const mockLlm = new FakeLLM({
@ -32,7 +49,7 @@ Here is the JSON Schema instance your output must adhere to. Include the enclosi
\`\`\`
`;
const chainWithFormatInstructions = getChainWithFormatInstructions(mockLlm);
const chainWithFormatInstructions = getChainWithFormatInstructions({ llm: mockLlm, prompts });
expect(chainWithFormatInstructions).toEqual({
chain: expect.any(Object),
formatInstructions: expectedFormatInstructions,

View file

@ -9,6 +9,7 @@ import type { ActionsClientLlm } from '@kbn/langchain/server';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { Runnable } from '@langchain/core/runnables';
import { GenerationPrompts } from '../prompts';
import { getOutputParser } from '../get_output_parser';
interface GetChainWithFormatInstructions {
@ -17,10 +18,14 @@ interface GetChainWithFormatInstructions {
llmType: string;
}
export const getChainWithFormatInstructions = (
llm: ActionsClientLlm
): GetChainWithFormatInstructions => {
const outputParser = getOutputParser();
export const getChainWithFormatInstructions = ({
llm,
prompts,
}: {
llm: ActionsClientLlm;
prompts: GenerationPrompts;
}): GetChainWithFormatInstructions => {
const outputParser = getOutputParser(prompts);
const formatInstructions = outputParser.getFormatInstructions();
const prompt = ChatPromptTemplate.fromTemplate(
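After this refactor, call sites pass a single options object so the prompts reach the output parser, as the generate and refine nodes in this PR now do:

```ts
// Call-site shape after the refactor (mirrors the generate/refine nodes).
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions({
  llm,
  prompts,
});
```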

View file

@ -6,12 +6,14 @@
*/
import { getCombinedAttackDiscoveryPrompt } from '.';
import { ATTACK_DISCOVERY_CONTINUE } from '../../../../../../prompt/prompts';
describe('getCombinedAttackDiscoveryPrompt', () => {
it('returns the initial query when there are no partial results', () => {
const result = getCombinedAttackDiscoveryPrompt({
anonymizedAlerts: ['alert1', 'alert2'],
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
combinedMaybePartialResults: '',
});
@ -31,6 +33,7 @@ alert2
const result = getCombinedAttackDiscoveryPrompt({
anonymizedAlerts: ['alert1', 'alert2'],
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
combinedMaybePartialResults: 'partialResults',
});

View file

@ -8,7 +8,6 @@
import { isEmpty } from 'lodash/fp';
import { getAlertsContextPrompt } from '../../generate/helpers/get_alerts_context_prompt';
import { getContinuePrompt } from '../get_continue_prompt';
/**
* Returns the initial query, or the initial query combined with a
@ -18,11 +17,13 @@ export const getCombinedAttackDiscoveryPrompt = ({
anonymizedAlerts,
attackDiscoveryPrompt,
combinedMaybePartialResults,
continuePrompt,
}: {
anonymizedAlerts: string[];
attackDiscoveryPrompt: string;
/** combined results that may contain incomplete JSON */
combinedMaybePartialResults: string;
continuePrompt: string;
}): string => {
const alertsContextPrompt = getAlertsContextPrompt({
anonymizedAlerts,
@ -33,7 +34,7 @@ export const getCombinedAttackDiscoveryPrompt = ({
? alertsContextPrompt // no partial results yet
: `${alertsContextPrompt}
${getContinuePrompt()}
${continuePrompt}
"""
${combinedMaybePartialResults}
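The continuation prompt is now threaded in from the resolved prompt set instead of being hard-coded via `getContinuePrompt()`. A usage sketch mirroring the updated tests:

```ts
// Build the combined query; combinedMaybePartialResults is empty on the
// first generation attempt, so only the alerts-context prompt is returned.
const query = getCombinedAttackDiscoveryPrompt({
  anonymizedAlerts: ['alert1', 'alert2'],
  attackDiscoveryPrompt: prompts.default,
  combinedMaybePartialResults: '',
  continuePrompt: prompts.continue,
});
```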

View file

@ -1,22 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getContinuePrompt } from '.';
describe('getContinuePrompt', () => {
it('returns the expected prompt string', () => {
const expectedPrompt = `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
1) it MUST conform to the schema above, because it will be checked against the JSON schema
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
`;
expect(getContinuePrompt()).toBe(expectedPrompt);
});
});

View file

@ -1,15 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getContinuePrompt =
(): string => `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
1) it MUST conform to the schema above, because it will be checked against the JSON schema
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
`;

View file

@ -1,16 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getDefaultAttackDiscoveryPrompt } from '.';
describe('getDefaultAttackDiscoveryPrompt', () => {
it('returns the default attack discovery prompt', () => {
expect(getDefaultAttackDiscoveryPrompt()).toEqual(
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds)."
);
});
});

View file

@ -1,9 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getDefaultAttackDiscoveryPrompt = (): string =>
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).";

View file

@ -5,10 +5,27 @@
* 2.0.
*/
import { getOutputParser } from '.';
import {
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_TITLE,
} from '../../../../../../prompt/prompts';
const prompts = {
detailsMarkdown: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
entitySummaryMarkdown: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
mitreAttackTactics: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
summaryMarkdown: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
title: ATTACK_DISCOVERY_GENERATION_TITLE,
insights: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
};
describe('getOutputParser', () => {
it('returns a structured output parser with the expected format instructions', () => {
const outputParser = getOutputParser();
const outputParser = getOutputParser(prompts);
const expected = `You must format your output as a JSON value that adheres to a given \"JSON Schema\" instance.

View file

@ -7,7 +7,8 @@
import { StructuredOutputParser } from 'langchain/output_parsers';
import { AttackDiscoveriesGenerationSchema } from '../../generate/schema';
import { GenerationPrompts } from '../prompts';
import { getAttackDiscoveriesGenerationSchema } from '../../generate/schema';
export const getOutputParser = () =>
StructuredOutputParser.fromZodSchema(AttackDiscoveriesGenerationSchema);
export const getOutputParser = (prompts: GenerationPrompts) =>
StructuredOutputParser.fromZodSchema(getAttackDiscoveriesGenerationSchema(prompts));
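A usage sketch: the parser now embeds the prompt-driven schema, so the format instructions sent to the LLM reflect any saved-object overrides:

```ts
const outputParser = getOutputParser(prompts);
const formatInstructions = outputParser.getFormatInstructions();
```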

View file

@ -10,6 +10,15 @@ import type { Logger } from '@kbn/core/server';
import { parseCombinedOrThrow } from '.';
import { getRawAttackDiscoveriesMock } from '../../../../../../../__mocks__/raw_attack_discoveries';
const prompts = {
detailsMarkdown: '',
entitySummaryMarkdown: '',
mitreAttackTactics: '',
summaryMarkdown: '',
title: '',
insights: '',
};
describe('parseCombinedOrThrow', () => {
const mockLogger: Logger = {
debug: jest.fn(),
@ -28,6 +37,7 @@ describe('parseCombinedOrThrow', () => {
nodeName,
llmType,
logger: mockLogger,
prompts,
};
it('returns an Attack discovery for each insight in a valid combined response', () => {

View file

@ -8,9 +8,10 @@
import type { Logger } from '@kbn/core/server';
import type { AttackDiscovery } from '@kbn/elastic-assistant-common';
import { GenerationPrompts } from '../prompts';
import { addTrailingBackticksIfNecessary } from '../add_trailing_backticks_if_necessary';
import { extractJson } from '../extract_json';
import { AttackDiscoveriesGenerationSchema } from '../../generate/schema';
import { getAttackDiscoveriesGenerationSchema } from '../../generate/schema';
export const parseCombinedOrThrow = ({
combinedResponse,
@ -18,6 +19,7 @@ export const parseCombinedOrThrow = ({
llmType,
logger,
nodeName,
prompts,
}: {
/** combined responses that may be valid JSON */
combinedResponse: string;
@ -25,6 +27,7 @@ export const parseCombinedOrThrow = ({
nodeName: string;
llmType: string;
logger?: Logger;
prompts: GenerationPrompts;
}): AttackDiscovery[] => {
const timestamp = new Date().toISOString();
@ -42,7 +45,7 @@ export const parseCombinedOrThrow = ({
`${nodeName} node is validating combined response (${llmType}) from attempt ${generationAttempts}`
);
const validatedResponse = AttackDiscoveriesGenerationSchema.parse(unvalidatedParsed);
const validatedResponse = getAttackDiscoveriesGenerationSchema(prompts).parse(unvalidatedParsed);
logger?.debug(
() =>

View file

@ -0,0 +1,118 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { PublicMethodsOf } from '@kbn/utility-types';
import { getAttackDiscoveryPrompts } from '.';
import { getPromptsByGroupId, promptDictionary } from '../../../../../../prompt';
import { promptGroupId } from '../../../../../../prompt/local_prompt_object';
jest.mock('../../../../../../prompt', () => {
const original = jest.requireActual('../../../../../../prompt');
return {
...original,
getPromptsByGroupId: jest.fn(),
};
});
const mockGetPromptsByGroupId = getPromptsByGroupId as jest.Mock;
describe('getAttackDiscoveryPrompts', () => {
const actionsClient = {} as jest.Mocked<PublicMethodsOf<ActionsClient>>;
const savedObjectsClient = {} as jest.Mocked<SavedObjectsClientContract>;
beforeEach(() => {
jest.clearAllMocks();
mockGetPromptsByGroupId.mockResolvedValue([
{ promptId: promptDictionary.attackDiscoveryDefault, prompt: 'Default Prompt' },
{ promptId: promptDictionary.attackDiscoveryRefine, prompt: 'Refine Prompt' },
{ promptId: promptDictionary.attackDiscoveryContinue, prompt: 'Continue Prompt' },
{
promptId: promptDictionary.attackDiscoveryDetailsMarkdown,
prompt: 'Details Markdown Prompt',
},
{
promptId: promptDictionary.attackDiscoveryEntitySummaryMarkdown,
prompt: 'Entity Summary Markdown Prompt',
},
{
promptId: promptDictionary.attackDiscoveryMitreAttackTactics,
prompt: 'Mitre Attack Tactics Prompt',
},
{
promptId: promptDictionary.attackDiscoverySummaryMarkdown,
prompt: 'Summary Markdown Prompt',
},
{ promptId: promptDictionary.attackDiscoveryGenerationTitle, prompt: 'Title Prompt' },
{ promptId: promptDictionary.attackDiscoveryGenerationInsights, prompt: 'Insights Prompt' },
]);
});
it('should return all prompts', async () => {
const result = await getAttackDiscoveryPrompts({
actionsClient,
connectorId: 'test-connector-id',
savedObjectsClient,
model: '2',
provider: 'gemini',
});
expect(mockGetPromptsByGroupId).toHaveBeenCalledWith(
expect.objectContaining({
connectorId: 'test-connector-id',
promptGroupId: promptGroupId.attackDiscovery,
model: '2',
provider: 'gemini',
promptIds: [
promptDictionary.attackDiscoveryDefault,
promptDictionary.attackDiscoveryRefine,
promptDictionary.attackDiscoveryContinue,
promptDictionary.attackDiscoveryDetailsMarkdown,
promptDictionary.attackDiscoveryEntitySummaryMarkdown,
promptDictionary.attackDiscoveryMitreAttackTactics,
promptDictionary.attackDiscoverySummaryMarkdown,
promptDictionary.attackDiscoveryGenerationTitle,
promptDictionary.attackDiscoveryGenerationInsights,
],
})
);
expect(result).toEqual({
default: 'Default Prompt',
refine: 'Refine Prompt',
continue: 'Continue Prompt',
detailsMarkdown: 'Details Markdown Prompt',
entitySummaryMarkdown: 'Entity Summary Markdown Prompt',
mitreAttackTactics: 'Mitre Attack Tactics Prompt',
summaryMarkdown: 'Summary Markdown Prompt',
title: 'Title Prompt',
insights: 'Insights Prompt',
});
});
it('should return empty strings for missing prompts', async () => {
mockGetPromptsByGroupId.mockResolvedValue([]);
const result = await getAttackDiscoveryPrompts({
actionsClient,
connectorId: 'test-connector-id',
savedObjectsClient,
});
expect(result).toEqual({
default: '',
refine: '',
continue: '',
detailsMarkdown: '',
entitySummaryMarkdown: '',
mitreAttackTactics: '',
summaryMarkdown: '',
title: '',
insights: '',
});
});
});

View file

@ -0,0 +1,101 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { PublicMethodsOf } from '@kbn/utility-types';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { getPromptsByGroupId, promptDictionary } from '../../../../../../prompt';
import { promptGroupId } from '../../../../../../prompt/local_prompt_object';
export interface AttackDiscoveryPrompts {
default: string;
refine: string;
continue: string;
}
export interface GenerationPrompts {
detailsMarkdown: string;
entitySummaryMarkdown: string;
mitreAttackTactics: string;
summaryMarkdown: string;
title: string;
insights: string;
}
export interface CombinedPrompts extends AttackDiscoveryPrompts, GenerationPrompts {}
export const getAttackDiscoveryPrompts = async ({
actionsClient,
connector,
connectorId,
model,
provider,
savedObjectsClient,
}: {
actionsClient: PublicMethodsOf<ActionsClient>;
connector?: Connector;
connectorId: string;
model?: string;
provider?: string;
savedObjectsClient: SavedObjectsClientContract;
}): Promise<CombinedPrompts> => {
const prompts = await getPromptsByGroupId({
actionsClient,
connector,
connectorId,
// if, in the future, OSS has a different prompt, add it as a model here
model,
promptGroupId: promptGroupId.attackDiscovery,
promptIds: [
promptDictionary.attackDiscoveryDefault,
promptDictionary.attackDiscoveryRefine,
promptDictionary.attackDiscoveryContinue,
promptDictionary.attackDiscoveryDetailsMarkdown,
promptDictionary.attackDiscoveryEntitySummaryMarkdown,
promptDictionary.attackDiscoveryMitreAttackTactics,
promptDictionary.attackDiscoverySummaryMarkdown,
promptDictionary.attackDiscoveryGenerationTitle,
promptDictionary.attackDiscoveryGenerationInsights,
],
provider,
savedObjectsClient,
});
return {
default:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryDefault)
?.prompt || '',
refine:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryRefine)
?.prompt || '',
continue:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryContinue)
?.prompt || '',
detailsMarkdown:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryDetailsMarkdown)
?.prompt || '',
entitySummaryMarkdown:
prompts.find(
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryEntitySummaryMarkdown
)?.prompt || '',
mitreAttackTactics:
prompts.find(
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryMitreAttackTactics
)?.prompt || '',
summaryMarkdown:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoverySummaryMarkdown)
?.prompt || '',
title:
prompts.find((prompt) => prompt.promptId === promptDictionary.attackDiscoveryGenerationTitle)
?.prompt || '',
insights:
prompts.find(
(prompt) => prompt.promptId === promptDictionary.attackDiscoveryGenerationInsights
)?.prompt || '',
};
};
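A call-site sketch, with clients and connector fields assumed: this resolves all nine attack-discovery prompts for a connector, preferring saved-object overrides for the given provider/model and falling back to an empty string when nothing is found, as the unit tests earlier in this diff show:

```ts
const prompts = await getAttackDiscoveryPrompts({
  actionsClient,
  connectorId: connector.id,
  model: '2', // illustrative values, mirroring the unit tests
  provider: 'gemini',
  savedObjectsClient,
});
// prompts.default, prompts.refine, prompts.continue, etc. seed the graph state.
```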

View file

@ -15,6 +15,7 @@ const initialState: GraphState = {
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
combinedGenerations: 'generation1generation2',
combinedRefinements: 'refinement1', // <-- existing refinements
continuePrompt: 'continue',
errors: [],
generationAttempts: 3,
generations: ['generation1', 'generation2'],

View file

@ -7,13 +7,14 @@
import { getCombinedRefinePrompt } from '.';
import { mockAttackDiscoveries } from '../../../../../../evaluation/__mocks__/mock_attack_discoveries';
import { getContinuePrompt } from '../../../helpers/get_continue_prompt';
import { ATTACK_DISCOVERY_CONTINUE } from '../../../../../../../prompt/prompts';
describe('getCombinedRefinePrompt', () => {
it('returns the base query when combinedRefinements is empty', () => {
const result = getCombinedRefinePrompt({
attackDiscoveryPrompt: 'Initial query',
combinedRefinements: '',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
refinePrompt: 'Refine prompt',
unrefinedResults: [...mockAttackDiscoveries],
});
@ -33,6 +34,7 @@ ${JSON.stringify(mockAttackDiscoveries, null, 2)}
const result = getCombinedRefinePrompt({
attackDiscoveryPrompt: 'Initial query',
combinedRefinements: 'Combined refinements',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
refinePrompt: 'Refine prompt',
unrefinedResults: [...mockAttackDiscoveries],
});
@ -47,7 +49,7 @@ ${JSON.stringify(mockAttackDiscoveries, null, 2)}
${getContinuePrompt()}
${ATTACK_DISCOVERY_CONTINUE}
"""
Combined refinements
@ -60,6 +62,7 @@ Combined refinements
const result = getCombinedRefinePrompt({
attackDiscoveryPrompt: 'Initial query',
combinedRefinements: '',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
refinePrompt: 'Refine prompt',
unrefinedResults: null,
});

View file

@ -8,19 +8,19 @@
import type { AttackDiscovery } from '@kbn/elastic-assistant-common';
import { isEmpty } from 'lodash/fp';
import { getContinuePrompt } from '../../../helpers/get_continue_prompt';
/**
* Returns a prompt that combines the initial query, a refine prompt, and partial results
*/
export const getCombinedRefinePrompt = ({
attackDiscoveryPrompt,
combinedRefinements,
continuePrompt,
refinePrompt,
unrefinedResults,
}: {
attackDiscoveryPrompt: string;
combinedRefinements: string;
continuePrompt: string;
refinePrompt: string;
unrefinedResults: AttackDiscovery[] | null;
}): string => {
@ -38,7 +38,7 @@ ${JSON.stringify(unrefinedResults, null, 2)}
? baseQuery // no partial results yet
: `${baseQuery}
${getContinuePrompt()}
${continuePrompt}
"""
${combinedRefinements}

View file

@ -1,19 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getDefaultRefinePrompt } from '.';
describe('getDefaultRefinePrompt', () => {
it('returns the default refine prompt string', () => {
const result = getDefaultRefinePrompt();
expect(result)
.toEqual(`You previously generated the following insights, but sometimes they represent the same attack.
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`);
});
});

View file

@ -1,11 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const getDefaultRefinePrompt =
(): string => `You previously generated the following insights, but sometimes they represent the same attack.
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`;

View file

@ -16,16 +16,26 @@ import {
mockAnonymizedAlertsReplacements,
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
import { getDefaultRefinePrompt } from './helpers/get_default_refine_prompt';
import { GraphState } from '../../types';
import {
getParsedAttackDiscoveriesMock,
getRawAttackDiscoveriesMock,
} from '../../../../../../__mocks__/raw_attack_discoveries';
import {
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
} from '../../../../../prompt/prompts';
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
const prompts = {
detailsMarkdown: '',
entitySummaryMarkdown: '',
mitreAttackTactics: '',
summaryMarkdown: '',
title: '',
insights: '',
};
export const mockUnrefinedAttackDiscoveries: AttackDiscovery[] = [
{
title: 'unrefinedTitle1',
@ -67,10 +77,11 @@ let mockLlm: ActionsClientLlm;
const initialGraphState: GraphState = {
attackDiscoveries: null,
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
anonymizedAlerts: [...mockAnonymizedAlerts],
combinedGenerations: 'gen1',
combinedRefinements: '',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
errors: [],
generationAttempts: 1,
generations: ['gen1'],
@ -79,7 +90,7 @@ const initialGraphState: GraphState = {
maxHallucinationFailures: 5,
maxRepeatedGenerations: 3,
refinements: [],
refinePrompt: getDefaultRefinePrompt(),
refinePrompt: ATTACK_DISCOVERY_REFINE,
replacements: {
...mockAnonymizedAlertsReplacements,
},
@ -106,17 +117,20 @@ describe('getRefineNode', () => {
const refineNode = getRefineNode({
llm: mockLlm,
logger: mockLogger,
prompts,
});
expect(typeof refineNode).toBe('function');
});
it('invokes the chain with the unrefinedResults from state and format instructions', async () => {
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlm, prompts }).chain
.invoke as jest.Mock;
const refineNode = getRefineNode({
llm: mockLlm,
logger: mockLogger,
prompts,
});
await refineNode(initialGraphState);
@ -125,7 +139,7 @@ describe('getRefineNode', () => {
format_instructions: ['mock format instructions'],
query: `${initialGraphState.attackDiscoveryPrompt}
${getDefaultRefinePrompt()}
${ATTACK_DISCOVERY_REFINE}
\"\"\"
${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
@ -140,7 +154,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(response);
@ -148,6 +162,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const refineNode = getRefineNode({
llm: mockLlm,
logger: mockLogger,
prompts,
});
const state = await refineNode(initialGraphState);
@ -170,14 +185,15 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const mockLlmWithHallucination = new FakeLLM({
response: hallucinatedResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithHallucination, prompts })
.chain.invoke as jest.Mock;
mockInvoke.mockResolvedValue(hallucinatedResponse);
const refineNode = getRefineNode({
llm: mockLlmWithHallucination,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -203,14 +219,17 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const mockLlmWithRepeatedGenerations = new FakeLLM({
response: repeatedResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
.invoke as jest.Mock;
const mockInvoke = getChainWithFormatInstructions({
llm: mockLlmWithRepeatedGenerations,
prompts,
}).chain.invoke as jest.Mock;
mockInvoke.mockResolvedValue(repeatedResponse);
const refineNode = getRefineNode({
llm: mockLlmWithRepeatedGenerations,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -236,7 +255,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const mockLlmWithResponse = new FakeLLM({
response,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(response);
@ -244,6 +263,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const refineNode = getRefineNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -275,7 +295,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const mockLlmWithResponse = new FakeLLM({
response: secondResponse,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(secondResponse);
@ -283,6 +303,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const refineNode = getRefineNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {
@ -309,7 +330,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const mockLlmWithResponse = new FakeLLM({
response,
}) as unknown as ActionsClientLlm;
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
const mockInvoke = getChainWithFormatInstructions({ llm: mockLlmWithResponse, prompts }).chain
.invoke as jest.Mock;
mockInvoke.mockResolvedValue(response);
@ -317,6 +338,7 @@ ${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
const refineNode = getRefineNode({
llm: mockLlmWithResponse,
logger: mockLogger,
prompts,
});
const withPreviousGenerations = {

View file

@ -6,8 +6,9 @@
*/
import type { ActionsClientLlm } from '@kbn/langchain/server';
import type { Logger } from '@kbn/core/server';
import { Logger } from '@kbn/core/server';
import { GenerationPrompts } from '../helpers/prompts';
import { discardPreviousRefinements } from './helpers/discard_previous_refinements';
import { extractJson } from '../helpers/extract_json';
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
@ -24,9 +25,11 @@ import type { GraphState } from '../../types';
export const getRefineNode = ({
llm,
logger,
prompts,
}: {
llm: ActionsClientLlm;
logger?: Logger;
prompts: GenerationPrompts;
}): ((state: GraphState) => Promise<GraphState>) => {
const refine = async (state: GraphState): Promise<GraphState> => {
logger?.debug(() => '---REFINE---');
@ -34,6 +37,7 @@ export const getRefineNode = ({
const {
attackDiscoveryPrompt,
combinedRefinements,
continuePrompt,
generationAttempts,
hallucinationFailures,
maxGenerationAttempts,
@ -51,11 +55,15 @@ export const getRefineNode = ({
const query = getCombinedRefinePrompt({
attackDiscoveryPrompt,
combinedRefinements,
continuePrompt,
refinePrompt,
unrefinedResults,
});
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions(llm);
const { chain, formatInstructions, llmType } = getChainWithFormatInstructions({
llm,
prompts,
});
logger?.debug(
() => `refine node is invoking the chain (${llmType}), attempt ${generationAttempts}`
@ -115,6 +123,7 @@ export const getRefineNode = ({
llmType,
logger,
nodeName: 'refine',
prompts,
});
return {

View file

@ -11,16 +11,20 @@ import { Replacements } from '@kbn/elastic-assistant-common';
import { getRetrieveAnonymizedAlertsNode } from '.';
import { mockAnonymizedAlerts } from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
import type { GraphState } from '../../types';
import {
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
} from '../../../../../prompt/prompts';
const initialGraphState: GraphState = {
attackDiscoveries: null,
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
attackDiscoveryPrompt: ATTACK_DISCOVERY_DEFAULT,
anonymizedAlerts: [],
combinedGenerations: '',
combinedRefinements: '',
continuePrompt: ATTACK_DISCOVERY_CONTINUE,
errors: [],
generationAttempts: 0,
generations: [],
@ -29,7 +33,7 @@ const initialGraphState: GraphState = {
maxHallucinationFailures: 5,
maxRepeatedGenerations: 3,
refinements: [],
refinePrompt: getDefaultRefinePrompt(),
refinePrompt: ATTACK_DISCOVERY_REFINE,
replacements: {},
unrefinedResults: null,
};

View file

@ -11,110 +11,117 @@ import {
DEFAULT_MAX_HALLUCINATION_FAILURES,
DEFAULT_MAX_REPEATED_GENERATIONS,
} from '../constants';
import { getDefaultAttackDiscoveryPrompt } from '../nodes/helpers/get_default_attack_discovery_prompt';
import { getDefaultRefinePrompt } from '../nodes/refine/helpers/get_default_refine_prompt';
const defaultAttackDiscoveryPrompt = getDefaultAttackDiscoveryPrompt();
const defaultRefinePrompt = getDefaultRefinePrompt();
import {
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
} from '../../../../prompt/prompts';
const defaultAttackDiscoveryPrompt = ATTACK_DISCOVERY_DEFAULT;
const defaultRefinePrompt = ATTACK_DISCOVERY_REFINE;
const prompts = {
continue: ATTACK_DISCOVERY_CONTINUE,
default: defaultAttackDiscoveryPrompt,
refine: defaultRefinePrompt,
};
describe('getDefaultGraphState', () => {
it('returns the expected default attackDiscoveries', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.attackDiscoveries?.default?.()).toBeNull();
});
it('returns the expected default attackDiscoveryPrompt', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.attackDiscoveryPrompt?.default?.()).toEqual(defaultAttackDiscoveryPrompt);
});
it('returns the expected default empty collection of anonymizedAlerts', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.anonymizedAlerts?.default?.()).toHaveLength(0);
});
it('returns the expected default combinedGenerations state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.combinedGenerations?.default?.()).toBe('');
});
it('returns the expected default combinedRefinements state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.combinedRefinements?.default?.()).toBe('');
});
it('returns the expected default errors state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.errors?.default?.()).toHaveLength(0);
});
it('returns the expected default generationAttempts state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.generationAttempts?.default?.()).toBe(0);
});
it('returns the expected default generations state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.generations?.default?.()).toHaveLength(0);
});
it('returns the expected default hallucinationFailures state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.hallucinationFailures?.default?.()).toBe(0);
});
it('returns the expected default refinePrompt state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.refinePrompt?.default?.()).toEqual(defaultRefinePrompt);
});
it('returns the expected default maxGenerationAttempts state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.maxGenerationAttempts?.default?.()).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
});
it('returns the expected default maxHallucinationFailures state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.maxHallucinationFailures?.default?.()).toBe(DEFAULT_MAX_HALLUCINATION_FAILURES);
});
it('returns the expected default maxRepeatedGenerations state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.maxRepeatedGenerations?.default?.()).toBe(DEFAULT_MAX_REPEATED_GENERATIONS);
});
it('returns the expected default refinements state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.refinements?.default?.()).toHaveLength(0);
});
it('returns the expected default replacements state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.replacements?.default?.()).toEqual({});
});
it('returns the expected default unrefinedResults state', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.unrefinedResults?.default?.()).toBeNull();
});
it('returns the expected default end', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.end?.default?.()).toBeUndefined();
});
@ -122,13 +129,13 @@ describe('getDefaultGraphState', () => {
it('returns the expected end when it is provided', () => {
const end = '2025-01-02T00:00:00.000Z';
const state = getDefaultGraphState({ end });
const state = getDefaultGraphState({ prompts, end });
expect(state.end?.default?.()).toEqual(end);
});
it('returns the expected default filter to be undefined', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.filter?.default?.()).toBeUndefined();
});
@ -155,13 +162,13 @@ describe('getDefaultGraphState', () => {
},
};
const state = getDefaultGraphState({ filter });
const state = getDefaultGraphState({ prompts, filter });
expect(state.filter?.default?.()).toEqual(filter);
});
it('returns the expected default start to be undefined', () => {
const state = getDefaultGraphState();
const state = getDefaultGraphState({ prompts });
expect(state.start?.default?.()).toBeUndefined();
});
@ -169,7 +176,7 @@ describe('getDefaultGraphState', () => {
it('returns the expected start when it is provided', () => {
const start = '2025-01-01T00:00:00.000Z';
const state = getDefaultGraphState({ start });
const state = getDefaultGraphState({ prompts, start });
expect(state.start?.default?.()).toEqual(start);
});

View file

@ -9,33 +9,34 @@ import { AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
import type { Document } from '@langchain/core/documents';
import type { StateGraphArgs } from '@langchain/langgraph';
import { AttackDiscoveryPrompts } from '../nodes/helpers/prompts';
import {
DEFAULT_MAX_GENERATION_ATTEMPTS,
DEFAULT_MAX_HALLUCINATION_FAILURES,
DEFAULT_MAX_REPEATED_GENERATIONS,
} from '../constants';
import { getDefaultAttackDiscoveryPrompt } from '../nodes/helpers/get_default_attack_discovery_prompt';
import { getDefaultRefinePrompt } from '../nodes/refine/helpers/get_default_refine_prompt';
import type { GraphState } from '../types';
export interface Options {
end?: string;
filter?: Record<string, unknown> | null;
prompts: AttackDiscoveryPrompts;
start?: string;
}
export const getDefaultGraphState = ({
end,
filter,
prompts,
start,
}: Options | undefined = {}): StateGraphArgs<GraphState>['channels'] => ({
}: Options): StateGraphArgs<GraphState>['channels'] => ({
attackDiscoveries: {
value: (x: AttackDiscovery[] | null, y?: AttackDiscovery[] | null) => y ?? x,
default: () => null,
},
attackDiscoveryPrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => getDefaultAttackDiscoveryPrompt(),
default: () => prompts.default,
},
anonymizedAlerts: {
value: (x: Document[], y?: Document[]) => y ?? x,
@ -49,6 +50,10 @@ export const getDefaultGraphState = ({
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
continuePrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => prompts.continue,
},
end: {
value: (x?: string | null, y?: string | null) => y ?? x,
default: () => end,
@ -75,7 +80,7 @@ export const getDefaultGraphState = ({
},
refinePrompt: {
value: (x: string, y?: string) => y ?? x,
default: () => getDefaultRefinePrompt(),
default: () => prompts.refine,
},
maxGenerationAttempts: {
value: (x: number, y?: number) => y ?? x,

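Note: the `value: (x, y) => y ?? x` pairs above are LangGraph channel reducers; each node may emit a partial state, and the reducer decides how a new value merges with the old one. A minimal standalone sketch of that merge behavior (plain TypeScript with hypothetical values; not the LangGraph API itself):

// Reducer used by most channels above: keep the old value unless a new one arrives.
const lastValueWins = <T>(x: T, y?: T): T => y ?? x;

let refinePrompt = 'configured refine prompt'; // e.g. prompts.refine
// A node that emits nothing leaves the channel unchanged:
refinePrompt = lastValueWins(refinePrompt, undefined); // 'configured refine prompt'
// A node that emits a replacement overwrites it:
refinePrompt = lastValueWins(refinePrompt, 'revised refine prompt'); // 'revised refine prompt'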
View file

@ -14,6 +14,7 @@ export interface GraphState {
anonymizedAlerts: Document[];
combinedGenerations: string;
combinedRefinements: string;
continuePrompt: string;
end?: string | null;
errors: string[];
filter?: Record<string, unknown> | null;

View file

@ -18,6 +18,7 @@ import type { InferenceServerStart } from '@kbn/inference-plugin/server';
import { AnalyticsServiceSetup } from '@kbn/core-analytics-server';
import { TelemetryParams } from '@kbn/langchain/server/tracers/telemetry/telemetry_tracer';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { ResponseBody } from '../types';
import type { AssistantTool } from '../../../types';
import { AIAssistantKnowledgeBaseDataClient } from '../../../ai_assistant_data_clients/knowledge_base';
@ -57,6 +58,7 @@ export interface AgentExecutorParams<T extends boolean> {
onLlmResponse?: OnLlmResponse;
request: KibanaRequest<unknown, unknown, ExecuteConnectorRequestBody>;
response?: KibanaResponseFactory;
savedObjectsClient: SavedObjectsClientContract;
size?: number;
systemPrompt?: string;
telemetry: AnalyticsServiceSetup;

View file

@ -14,6 +14,9 @@ import type { Logger } from '@kbn/logging';
import { BaseMessage } from '@langchain/core/messages';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { ConversationResponse, Replacements } from '@kbn/elastic-assistant-common';
import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { AgentState, NodeParamsBase } from './types';
import { AssistantDataClients } from '../../executors/types';
@ -30,10 +33,12 @@ import { NodeType } from './constants';
export const DEFAULT_ASSISTANT_GRAPH_ID = 'Default Security Assistant Graph';
export interface GetDefaultAssistantGraphParams {
actionsClient: PublicMethodsOf<ActionsClient>;
agentRunnable: AgentRunnableSequence;
dataClients?: AssistantDataClients;
createLlmInstance: () => BaseChatModel;
logger: Logger;
savedObjectsClient: SavedObjectsClientContract;
signal?: AbortSignal;
tools: StructuredTool[];
replacements: Replacements;
@ -42,10 +47,12 @@ export interface GetDefaultAssistantGraphParams {
export type DefaultAssistantGraph = ReturnType<typeof getDefaultAssistantGraph>;
export const getDefaultAssistantGraph = ({
actionsClient,
agentRunnable,
dataClients,
createLlmInstance,
logger,
savedObjectsClient,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
signal,
tools,
@ -97,6 +104,10 @@ export const getDefaultAssistantGraph = ({
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
connectorId: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
conversation: {
value: (x: ConversationResponse | undefined, y?: ConversationResponse | undefined) =>
y ?? x,
@ -114,7 +125,9 @@ export const getDefaultAssistantGraph = ({
// Default node parameters
const nodeParams: NodeParamsBase = {
actionsClient,
logger,
savedObjectsClient,
};
// Put together a new graph using default state from above

View file

@ -56,6 +56,7 @@ describe('streamGraph', () => {
input: 'input',
responseLanguage: 'English',
llmType: 'openai',
connectorId: '123',
},
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,

View file

@ -18,6 +18,7 @@ import {
createStructuredChatAgent,
createToolCallingAgent,
} from 'langchain/agents';
import { savedObjectsClientMock } from '@kbn/core-saved-objects-api-server-mocks';
jest.mock('./graph');
jest.mock('./helpers');
jest.mock('langchain/agents');
@ -41,6 +42,13 @@ describe('callAssistantGraph', () => {
},
};
const savedObjectsClient = savedObjectsClientMock.create();
savedObjectsClient.find = jest.fn().mockResolvedValue({
page: 1,
per_page: 20,
total: 0,
saved_objects: [],
});
const defaultParams = {
actionsClient: actionsClientMock.create(),
alertsIndexPattern: 'test-pattern',
@ -60,6 +68,7 @@ describe('callAssistantGraph', () => {
onNewReplacements: jest.fn(),
replacements: [],
request: mockRequest,
savedObjectsClient,
size: 1,
systemPrompt: 'test-prompt',
telemetry: {},
@ -167,6 +176,12 @@ describe('callAssistantGraph', () => {
});
it('creates OpenAIToolsAgent for inference llmType', async () => {
defaultParams.actionsClient.get = jest.fn().mockResolvedValue({
config: {
provider: 'elastic',
providerConfig: { model_id: 'rainbow-sprinkles' },
},
});
const params = { ...defaultParams, llmType: 'inference' };
await callAssistantGraph(params);

View file

@ -14,11 +14,14 @@ import {
} from 'langchain/agents';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { promptGroupId } from '../../../prompt/local_prompt_object';
import { getModelOrOss } from '../../../prompt/helpers';
import { getPrompt, promptDictionary } from '../../../prompt';
import { getLlmClass } from '../../../../routes/utils';
import { EsAnonymizationFieldsSchema } from '../../../../ai_assistant_data_clients/anonymization_fields/types';
import { AssistantToolParams } from '../../../../types';
import { AgentExecutor } from '../../executors/types';
import { formatPrompt, formatPromptStructured, systemPrompts } from './prompts';
import { formatPrompt, formatPromptStructured } from './prompts';
import { GraphInputs } from './types';
import { getDefaultAssistantGraph } from './graph';
import { invokeGraph, streamGraph } from './helpers';
@ -44,6 +47,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
onNewReplacements,
replacements,
request,
savedObjectsClient,
size,
systemPrompt,
telemetry,
@ -130,29 +134,36 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
}
}
const defaultSystemPrompt = await getPrompt({
actionsClient,
connectorId,
model: getModelOrOss(llmType, isOssModel, request.body.model),
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: llmType,
savedObjectsClient,
});
const agentRunnable =
isOpenAI || llmType === 'inference'
? await createOpenAIToolsAgent({
llm: createLlmInstance(),
tools,
prompt: formatPrompt(systemPrompts.openai, systemPrompt),
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? await createToolCallingAgent({
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock'
? formatPrompt(systemPrompts.bedrock, systemPrompt)
: formatPrompt(systemPrompts.gemini, systemPrompt),
prompt: formatPrompt(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
})
: // used with OSS models
await createStructuredChatAgent({
llm: createLlmInstance(),
tools,
prompt: formatPromptStructured(systemPrompts.structuredChat, systemPrompt),
prompt: formatPromptStructured(defaultSystemPrompt, systemPrompt),
streamRunnable: isStream,
});
@ -174,6 +185,8 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// we need to pass it like this or streaming does not work for bedrock
createLlmInstance,
logger,
actionsClient,
savedObjectsClient,
tools,
replacements,
// some chat models (bedrock) require a signal to be passed on agent invoke rather than the signal passed to the chat model
@ -182,6 +195,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
const inputs: GraphInputs = {
responseLanguage,
conversationId,
connectorId,
llmType,
isStream,
isOssModel,

View file

@ -7,7 +7,8 @@
import { RunnableConfig } from '@langchain/core/runnables';
import { AgentRunnableSequence } from 'langchain/dist/agents/agent';
import { formatLatestUserMessage } from '../prompts';
import { promptGroupId } from '../../../../prompt/local_prompt_object';
import { getPrompt, promptDictionary } from '../../../../prompt';
import { AgentState, NodeParamsBase } from '../types';
import { NodeType } from '../constants';
import { AIAssistantKnowledgeBaseDataClient } from '../../../../../ai_assistant_data_clients/knowledge_base';
@ -34,7 +35,9 @@ const NO_KNOWLEDGE_HISTORY = '[No existing knowledge history]';
* @param kbDataClient - Data client for accessing the Knowledge Base on behalf of the current user
*/
export async function runAgent({
actionsClient,
logger,
savedObjectsClient,
state,
agentRunnable,
config,
@ -43,6 +46,17 @@ export async function runAgent({
logger.debug(() => `${NodeType.AGENT}: Node state:\n${JSON.stringify(state, null, 2)}`);
const knowledgeHistory = await kbDataClient?.getRequiredKnowledgeBaseDocumentEntries();
const userPrompt =
state.llmType === 'gemini'
? await getPrompt({
actionsClient,
connectorId: state.connectorId,
promptId: promptDictionary.userPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'gemini',
savedObjectsClient,
})
: '';
const agentOutcome = await agentRunnable
.withConfig({ tags: [AGENT_NODE_TAG], signal: config?.signal })
.invoke(
@ -54,7 +68,7 @@ export async function runAgent({
: NO_KNOWLEDGE_HISTORY
}`,
// prepend any user prompt (gemini)
input: formatLatestUserMessage(state.input, state.llmType),
input: `${userPrompt}${state.input}`,
chat_history: state.messages, // TODO: Message de-dupe with ...state spread
},
config

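The change above replaces the gemini-only formatLatestUserMessage helper with a prompt fetched through getPrompt, which is then prepended to the user input. A rough sketch of the resulting behavior, assuming the stored gemini user prompt resembles GEMINI_USER_PROMPT (values are illustrative):

// Hypothetical illustration of the prepend logic in runAgent.
const geminiUserPrompt =
  'Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n';

const buildInput = (userPrompt: string, input: string): string => `${userPrompt}${input}`;

buildInput(geminiUserPrompt, 'Summarize the open alerts'); // gemini: prompt is prepended
buildInput('', 'Summarize the open alerts'); // other providers: input passes through unchanged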
View file

@ -1,78 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
// TODO determine whether or not system prompts should be i18n'd
const YOU_ARE_A_HELPFUL_EXPERT_ASSISTANT =
'You are a security analyst and expert in resolving security incidents. Your role is to assist by answering questions about Elastic Security.';
const IF_YOU_DONT_KNOW_THE_ANSWER = 'Do not answer questions unrelated to Elastic Security.';
export const KNOWLEDGE_HISTORY =
'If available, use the Knowledge History provided to try and answer the question. If not provided, you can try and query for additional knowledge via the KnowledgeBaseRetrievalTool.';
export const DEFAULT_SYSTEM_PROMPT = `${YOU_ARE_A_HELPFUL_EXPERT_ASSISTANT} ${IF_YOU_DONT_KNOW_THE_ANSWER} ${KNOWLEDGE_HISTORY}`;
// system prompt from @afirstenberg
const BASE_GEMINI_PROMPT =
'You are an assistant that is an expert at using tools and Elastic Security, doing your best to use these tools to answer questions or follow instructions. It is very important to use tools to answer the question or follow the instructions rather than coming up with your own answer. Tool calls are good. Sometimes you may need to make several tool calls to accomplish the task or get an answer to the question that was asked. Use as many tool calls as necessary.';
const KB_CATCH =
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;

View file

@ -6,13 +6,6 @@
*/
import { ChatPromptTemplate } from '@langchain/core/prompts';
import {
BEDROCK_SYSTEM_PROMPT,
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './nodes/translations';
export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
@ -23,20 +16,6 @@ export const formatPrompt = (prompt: string, additionalPrompt?: string) =>
['placeholder', '{agent_scratchpad}'],
]);
export const systemPrompts = {
openai: DEFAULT_SYSTEM_PROMPT,
bedrock: `${DEFAULT_SYSTEM_PROMPT} ${BEDROCK_SYSTEM_PROMPT}`,
// The default prompt overwhelms gemini, do not prepend
gemini: GEMINI_SYSTEM_PROMPT,
structuredChat: STRUCTURED_SYSTEM_PROMPT,
};
export const openAIFunctionAgentPrompt = formatPrompt(systemPrompts.openai);
export const bedrockToolCallingAgentPrompt = formatPrompt(systemPrompts.bedrock);
export const geminiToolCallingAgentPrompt = formatPrompt(systemPrompts.gemini);
export const formatPromptStructured = (prompt: string, additionalPrompt?: string) =>
ChatPromptTemplate.fromMessages([
['system', additionalPrompt ? `${prompt}\n\n${additionalPrompt}` : prompt],
@ -47,18 +26,3 @@ export const formatPromptStructured = (prompt: string, additionalPrompt?: string
'{input}\n\n{agent_scratchpad}\n\n(reminder to respond in a JSON blob no matter what)',
],
]);
export const structuredChatAgentPrompt = formatPromptStructured(systemPrompts.structuredChat);
/**
* If Gemini is the llmType,
* Adds a user prompt for the latest message in a conversation
* @param prompt
* @param llmType
*/
export const formatLatestUserMessage = (prompt: string, llmType?: string): string => {
if (llmType === 'gemini') {
return `${GEMINI_USER_PROMPT}${prompt}`;
}
return prompt;
};

View file

@ -9,6 +9,9 @@ import { BaseMessage } from '@langchain/core/messages';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import type { Logger } from '@kbn/logging';
import { ConversationResponse } from '@kbn/elastic-assistant-common';
import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
export interface AgentStateBase {
agentOutcome?: AgentAction | AgentFinish;
@ -16,6 +19,7 @@ export interface AgentStateBase {
}
export interface GraphInputs {
connectorId: string;
conversationId?: string;
llmType?: string;
isStream?: boolean;
@ -34,10 +38,13 @@ export interface AgentState extends AgentStateBase {
isOssModel: boolean;
llmType: string;
responseLanguage: string;
connectorId: string;
conversation: ConversationResponse | undefined;
conversationId: string;
}
export interface NodeParamsBase {
actionsClient: PublicMethodsOf<ActionsClient>;
logger: Logger;
savedObjectsClient: SavedObjectsClientContract;
}

View file

@ -0,0 +1,520 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { getPrompt, getPromptsByGroupId } from './get_prompt';
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { BEDROCK_SYSTEM_PROMPT, DEFAULT_SYSTEM_PROMPT, GEMINI_USER_PROMPT } from './prompts';
import { promptDictionary, promptGroupId } from './local_prompt_object';
jest.mock('@kbn/core-saved-objects-api-server');
jest.mock('@kbn/actions-plugin/server');
const defaultConnector = {
id: 'mock',
name: 'Mock',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
actionTypeId: '.inference',
};
describe('get_prompt', () => {
let savedObjectsClient: jest.Mocked<SavedObjectsClientContract>;
let actionsClient: jest.Mocked<ActionsClient>;
beforeEach(() => {
jest.clearAllMocks();
savedObjectsClient = {
find: jest.fn().mockResolvedValue({
page: 1,
per_page: 20,
total: 5,
saved_objects: [
{
type: 'security-ai-prompt',
id: '977b39b8-5bb9-4530-9a39-7aa7084fb5c0',
attributes: {
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
model: 'gpt-4o',
description: 'Default prompt for AI Assistant system prompt.',
prompt: {
default: 'Hello world this is a system prompt',
},
},
references: [],
managed: false,
updated_at: '2025-01-22T18:44:35.271Z',
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
created_at: '2025-01-22T18:44:35.271Z',
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
version: 'Wzk0MiwxXQ==',
coreMigrationVersion: '8.8.0',
score: 0.13353139,
},
{
type: 'security-ai-prompt',
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
attributes: {
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
description: 'Default prompt for AI Assistant system prompt.',
prompt: {
default: 'Hello world this is a system prompt no model',
},
},
references: [],
managed: false,
updated_at: '2025-01-22T19:11:48.806Z',
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
created_at: '2025-01-22T19:11:48.806Z',
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
version: 'Wzk4MCwxXQ==',
coreMigrationVersion: '8.8.0',
score: 0.13353139,
},
{
type: 'security-ai-prompt',
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
attributes: {
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'bedrock',
description: 'Default prompt for AI Assistant system prompt.',
prompt: {
default: 'Hello world this is a system prompt for bedrock',
},
},
references: [],
managed: false,
updated_at: '2025-01-22T19:11:48.806Z',
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
created_at: '2025-01-22T19:11:48.806Z',
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
version: 'Wzk4MCwxXQ==',
coreMigrationVersion: '8.8.0',
score: 0.13353139,
},
{
type: 'security-ai-prompt',
id: 'd6dacb9b-1029-4c4c-85e1-e4f97b31c7f4',
attributes: {
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'bedrock',
model: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
description: 'Default prompt for AI Assistant system prompt.',
prompt: {
default: 'Hello world this is a system prompt for bedrock claude-3-5-sonnet',
},
},
references: [],
managed: false,
updated_at: '2025-01-22T19:11:48.806Z',
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
created_at: '2025-01-22T19:11:48.806Z',
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
version: 'Wzk4MCwxXQ==',
coreMigrationVersion: '8.8.0',
score: 0.13353139,
},
{
type: 'security-ai-prompt',
id: 'da530fad-87ce-49c3-a088-08073e5034d6',
attributes: {
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
description: 'Default prompt for AI Assistant system prompt.',
prompt: {
default: 'Hello world this is a system prompt no model, no provider',
},
},
references: [],
managed: false,
updated_at: '2025-01-22T19:12:12.911Z',
updated_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
created_at: '2025-01-22T19:12:12.911Z',
created_by: 'u_mGBROF_q5bmFCATbLXAcCwKa0k8JvONAwSruelyKA5E_0',
version: 'Wzk4MiwxXQ==',
coreMigrationVersion: '8.8.0',
score: 0.13353139,
},
],
}),
} as unknown as jest.Mocked<SavedObjectsClientContract>;
actionsClient = {
get: jest.fn().mockResolvedValue({
config: {
provider: 'openai',
providerConfig: { model_id: 'gpt-4o' },
},
}),
} as unknown as jest.Mocked<ActionsClient>;
});
describe('getPrompt', () => {
it('returns the prompt matching provider and model', async () => {
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
model: 'gpt-4o',
actionsClient,
connectorId: 'connector-123',
});
expect(actionsClient.get).not.toHaveBeenCalled();
expect(result).toBe('Hello world this is a system prompt');
});
it('returns the prompt matching provider when model does not have a match', async () => {
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
model: 'gpt-4o-mini',
actionsClient,
connectorId: 'connector-123',
});
expect(actionsClient.get).not.toHaveBeenCalled();
expect(result).toBe('Hello world this is a system prompt no model');
});
it('returns the prompt matching provider when model is not provided', async () => {
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
actionsClient,
connectorId: 'connector-123',
});
expect(actionsClient.get).toHaveBeenCalled();
expect(result).toBe('Hello world this is a system prompt no model');
});
it('returns the default prompt when there is no match on provider', async () => {
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'badone',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe('Hello world this is a system prompt no model, no provider');
});
it('defaults provider to bedrock if provider is "inference"', async () => {
actionsClient.get.mockResolvedValue(defaultConnector);
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'inference',
model: 'gpt-4o',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe('Hello world this is a system prompt for bedrock');
});
it('returns the expected prompt when provider is "elastic" and the model matches elasticModelDictionary', async () => {
actionsClient.get.mockResolvedValue({
...defaultConnector,
config: {
provider: 'elastic',
providerConfig: { model_id: 'rainbow-sprinkles' },
},
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'inference',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
});
it('returns the bedrock prompt when provider is "elastic" but model does not match elasticModelDictionary', async () => {
actionsClient.get.mockResolvedValue({
...defaultConnector,
config: {
provider: 'elastic',
providerConfig: { model_id: 'unknown-model' },
},
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'inference',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe('Hello world this is a system prompt for bedrock');
});
it('returns the local provider prompt when no saved prompts are found and a provider is given', async () => {
savedObjectsClient.find.mockResolvedValue({
page: 1,
per_page: 20,
total: 0,
saved_objects: [],
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
actionsClient,
provider: 'bedrock',
connectorId: 'connector-123',
});
expect(result).toBe(BEDROCK_SYSTEM_PROMPT);
});
it('returns the default prompt when no prompts are found', async () => {
savedObjectsClient.find.mockResolvedValue({
page: 1,
per_page: 20,
total: 0,
saved_objects: [],
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe(DEFAULT_SYSTEM_PROMPT);
});
it('throws an error when no prompt matches the promptId', async () => {
savedObjectsClient.find.mockResolvedValue({
page: 1,
per_page: 20,
total: 0,
saved_objects: [],
});
await expect(
getPrompt({
savedObjectsClient,
promptId: 'nonexistent-prompt',
promptGroupId: 'nonexistent-group',
actionsClient,
connectorId: 'connector-123',
})
).rejects.toThrow(
'Prompt not found for promptId: nonexistent-prompt and promptGroupId: nonexistent-group'
);
});
it('handles invalid connector configuration gracefully when provider is "inference"', async () => {
actionsClient.get.mockResolvedValue({
...defaultConnector,
config: {},
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'inference',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toBe('Hello world this is a system prompt for bedrock');
});
it('retrieves the connector when no model or provider is provided', async () => {
actionsClient.get.mockResolvedValue({
...defaultConnector,
actionTypeId: '.bedrock',
config: {
defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
},
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
actionsClient,
connectorId: 'connector-123',
});
expect(actionsClient.get).toHaveBeenCalled();
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
});
it('retrieves the connector when no model is provided', async () => {
actionsClient.get.mockResolvedValue({
...defaultConnector,
actionTypeId: '.bedrock',
config: {
defaultModel: 'us.anthropic.claude-3-5-sonnet-20240620-v1:0',
},
});
const result = await getPrompt({
savedObjectsClient,
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'bedrock',
actionsClient,
connectorId: 'connector-123',
});
expect(actionsClient.get).toHaveBeenCalled();
expect(result).toBe('Hello world this is a system prompt for bedrock claude-3-5-sonnet');
});
});
describe('getPromptsByGroupId', () => {
it('returns prompts matching the provided promptIds', async () => {
const result = await getPromptsByGroupId({
savedObjectsClient,
promptIds: [promptDictionary.systemPrompt],
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
model: 'gpt-4o',
actionsClient,
connectorId: 'connector-123',
});
expect(savedObjectsClient.find).toHaveBeenCalledWith({
type: 'security-ai-prompt',
searchFields: ['promptGroupId'],
search: promptGroupId.aiAssistant,
});
expect(result).toEqual([
{
promptId: promptDictionary.systemPrompt,
prompt: 'Hello world this is a system prompt',
},
]);
});
it('returns prompts matching the provided promptIds for gemini', async () => {
const result = await getPromptsByGroupId({
savedObjectsClient,
promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt],
promptGroupId: promptGroupId.aiAssistant,
provider: 'gemini',
actionsClient,
connectorId: 'connector-123',
});
expect(result).toEqual([
{
promptId: promptDictionary.systemPrompt,
prompt: 'Hello world this is a system prompt no model, no provider',
},
{
promptId: promptDictionary.userPrompt,
prompt: GEMINI_USER_PROMPT,
},
]);
});
it('returns prompts matching the provided promptIds when a connector is given', async () => {
const result = await getPromptsByGroupId({
savedObjectsClient,
promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt],
promptGroupId: promptGroupId.aiAssistant,
connector: {
actionTypeId: '.gemini',
config: {
defaultModel: 'gemini-1.5-pro-002',
},
id: 'connector-123',
name: 'Gemini',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
},
actionsClient,
connectorId: 'connector-123',
});
expect(result).toEqual([
{
promptId: promptDictionary.systemPrompt,
prompt: 'Hello world this is a system prompt no model, no provider',
},
{
promptId: promptDictionary.userPrompt,
prompt: GEMINI_USER_PROMPT,
},
]);
});
it('returns prompts matching the provided promptIds when an inference connector is given', async () => {
const result = await getPromptsByGroupId({
savedObjectsClient,
promptIds: [promptDictionary.systemPrompt],
promptGroupId: promptGroupId.aiAssistant,
connector: {
actionTypeId: '.inference',
config: {
provider: 'elastic',
providerConfig: { model_id: 'rainbow-sprinkles' },
},
id: 'connector-123',
name: 'Inference',
isPreconfigured: false,
isDeprecated: false,
isSystemAction: false,
},
actionsClient,
connectorId: 'connector-123',
});
expect(result).toEqual([
{
promptId: promptDictionary.systemPrompt,
prompt: 'Hello world this is a system prompt for bedrock',
},
]);
});
it('throws an error when a prompt is missing', async () => {
savedObjectsClient.find.mockResolvedValue({
page: 1,
per_page: 20,
total: 0,
saved_objects: [],
});
await expect(
getPromptsByGroupId({
savedObjectsClient,
promptIds: [promptDictionary.systemPrompt, 'fake-id'],
promptGroupId: promptGroupId.aiAssistant,
actionsClient,
connectorId: 'connector-123',
})
).rejects.toThrow('Prompt not found for promptId: fake-id and promptGroupId: aiAssistant');
});
});
});

View file

@ -0,0 +1,222 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { PublicMethodsOf } from '@kbn/utility-types';
import { ActionsClient } from '@kbn/actions-plugin/server';
import type { Connector } from '@kbn/actions-plugin/server/application/connector/types';
import { elasticModelDictionary } from '@kbn/inference-common';
import { Prompt } from './types';
import { localPrompts } from './local_prompt_object';
import { getLlmType } from '../../routes/utils';
import { promptSavedObjectType } from '../../../common/constants';
interface GetPromptArgs {
actionsClient: PublicMethodsOf<ActionsClient>;
connector?: Connector;
connectorId: string;
model?: string;
promptId: string;
promptGroupId: string;
provider?: string;
savedObjectsClient: SavedObjectsClientContract;
}
interface GetPromptsByGroupIdArgs extends Omit<GetPromptArgs, 'promptId'> {
promptGroupId: string;
promptIds: string[];
}
type PromptArray = Array<{ promptId: string; prompt: string }>;
/**
* Get prompts by feature (promptGroupId)
* Provide either model + provider or a connector to avoid an extra call to fetch the connector
* @param actionsClient - actions client
* @param connector - connector; when provided, model and provider can be omitted
* @param connectorId - connector id
* @param model - model; not needed when a connector is provided
* @param promptGroupId - feature id, shared across the given promptIds
* @param promptIds - prompt ids that share the promptGroupId
* @param provider - provider; not needed when a connector is provided
* @param savedObjectsClient - saved objects client
*/
export const getPromptsByGroupId = async ({
actionsClient,
connector,
connectorId,
model: providedModel,
promptGroupId,
promptIds,
provider: providedProvider,
savedObjectsClient,
}: GetPromptsByGroupIdArgs): Promise<PromptArray> => {
const { provider, model } = await resolveProviderAndModel({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector: connector,
});
const prompts = await savedObjectsClient.find<Prompt>({
type: promptSavedObjectType,
searchFields: ['promptGroupId'],
search: promptGroupId,
});
const promptsOnly = prompts?.saved_objects.map((p) => p.attributes) ?? [];
return promptIds.map((promptId) => {
const prompt = findPromptEntry({
prompts: promptsOnly.filter((p) => p.promptId === promptId) ?? [],
promptId,
promptGroupId,
provider,
model,
});
if (!prompt) {
throw new Error(
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
);
}
return {
promptId,
prompt,
};
});
};
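A hypothetical call, assuming savedObjectsClient and actionsClient are already in scope (the connector id is illustrative):

// Sketch: fetch the assistant system and user prompts with one saved-objects query.
const assistantPrompts = await getPromptsByGroupId({
  actionsClient,
  connectorId: 'my-gemini-connector', // hypothetical id
  promptGroupId: promptGroupId.aiAssistant,
  promptIds: [promptDictionary.systemPrompt, promptDictionary.userPrompt],
  provider: 'gemini',
  savedObjectsClient,
});
// => [{ promptId: 'systemPrompt', prompt: '...' }, { promptId: 'userPrompt', prompt: '...' }]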
/**
* Get prompt by promptId
* Provide either model + provider or a connector to avoid an extra call to fetch the connector
* @param actionsClient - actions client
* @param connector - connector; when provided, model and provider can be omitted
* @param connectorId - connector id
* @param model - model; not needed when a connector is provided
* @param promptId - prompt id
* @param promptGroupId - feature id, shared across promptIds
* @param provider - provider; not needed when a connector is provided
* @param savedObjectsClient - saved objects client
*/
export const getPrompt = async ({
actionsClient,
connector,
connectorId,
model: providedModel,
promptGroupId,
promptId,
provider: providedProvider,
savedObjectsClient,
}: GetPromptArgs): Promise<string> => {
const { provider, model } = await resolveProviderAndModel({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector: connector,
});
const prompts = await savedObjectsClient.find<Prompt>({
type: promptSavedObjectType,
filter: `${promptSavedObjectType}.attributes.promptId: "${promptId}" AND ${promptSavedObjectType}.attributes.promptGroupId: "${promptGroupId}"`,
fields: ['provider', 'model', 'prompt'],
});
const prompt = findPromptEntry({
prompts: prompts?.saved_objects.map((p) => p.attributes) ?? [],
promptId,
promptGroupId,
provider,
model,
});
if (!prompt) {
throw new Error(
`Prompt not found for promptId: ${promptId} and promptGroupId: ${promptGroupId}`
);
}
return prompt;
};
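A hypothetical call mirroring how callAssistantGraph resolves its default system prompt (the connector id is illustrative):

// Sketch: resolve a single prompt, falling back from provider+model to provider-only
// to the generic entry, and finally to the bundled localPrompts.
const systemPrompt = await getPrompt({
  actionsClient,
  connectorId: 'connector-123', // hypothetical id
  promptId: promptDictionary.systemPrompt,
  promptGroupId: promptGroupId.aiAssistant,
  provider: 'bedrock',
  savedObjectsClient,
});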
const resolveProviderAndModel = async ({
providedProvider,
providedModel,
connectorId,
actionsClient,
providedConnector,
}: {
providedProvider: string | undefined;
providedModel: string | undefined;
connectorId: string;
actionsClient: PublicMethodsOf<ActionsClient>;
providedConnector?: Connector;
}): Promise<{ provider?: string; model?: string }> => {
let model = providedModel;
let provider = providedProvider;
if (!provider || !model || provider === 'inference') {
const connector = providedConnector ?? (await actionsClient.get({ id: connectorId }));
if (provider === 'inference' && connector.config) {
provider = connector.config.provider || provider;
model = connector.config.providerConfig?.model_id || model;
if (provider === 'elastic' && model) {
provider = elasticModelDictionary[model]?.provider || 'inference';
model = elasticModelDictionary[model]?.model;
}
} else if (connector.config) {
provider = provider || getLlmType(connector.actionTypeId);
model = model || connector.config.defaultModel;
}
}
return { provider: provider === 'inference' ? 'bedrock' : provider, model };
};
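To make the resolution order concrete, here is a hedged trace of resolveProviderAndModel for an Elastic Inference Service connector (the connector shape is illustrative):

// Given provider === 'inference' and a connector like:
//   { actionTypeId: '.inference',
//     config: { provider: 'elastic', providerConfig: { model_id: 'rainbow-sprinkles' } } }
// 1. provider 'inference' (or a missing provider/model) triggers a connector lookup.
// 2. provider becomes 'elastic' and model becomes 'rainbow-sprinkles' from the connector config.
// 3. 'elastic' plus a known model is remapped via elasticModelDictionary:
//    provider -> 'bedrock', model -> 'us.anthropic.claude-3-5-sonnet-20240620-v1:0'.
// 4. On return, any provider still equal to 'inference' is normalized to 'bedrock'.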
const findPrompt = ({
prompts,
conditions,
}: {
prompts: Array<{ provider?: string; model?: string; prompt: { default: string } }>;
conditions: Array<(prompt: { provider?: string; model?: string }) => boolean>;
}): string | undefined => {
for (const condition of conditions) {
const match = prompts.find(condition);
if (match) return match.prompt.default;
}
return undefined;
};
const findPromptEntry = ({
prompts,
promptId,
promptGroupId,
provider,
model,
}: {
prompts: Prompt[];
promptId: string;
promptGroupId: string;
provider?: string;
model?: string;
}): string | undefined => {
const conditions = [
(prompt: { provider?: string; model?: string }) =>
prompt.provider === provider && prompt.model === model,
(prompt: { provider?: string; model?: string }) =>
prompt.provider === provider && !prompt.model,
(prompt: { provider?: string; model?: string }) => !prompt.provider && !prompt.model,
];
return (
findPrompt({ prompts, conditions }) ??
findPrompt({
prompts: localPrompts.filter(
(p) => p.promptId === promptId && p.promptGroupId === promptGroupId
),
conditions,
})
);
};
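The matching order above amounts to three passes over the candidate prompts. A small standalone sketch of that precedence (types simplified; not the exported API):

interface CandidatePrompt {
  provider?: string;
  model?: string;
  prompt: { default: string };
}

// Pass 1: exact provider + model; pass 2: provider only; pass 3: generic fallback.
const pickPrompt = (
  candidates: CandidatePrompt[],
  provider?: string,
  model?: string
): string | undefined =>
  (
    candidates.find((p) => p.provider === provider && p.model === model) ??
    candidates.find((p) => p.provider === provider && !p.model) ??
    candidates.find((p) => !p.provider && !p.model)
  )?.prompt.default;

// pickPrompt(candidates, 'openai', 'gpt-4o') prefers the gpt-4o entry, then any
// openai entry, then the provider-less default.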

View file

@ -0,0 +1,20 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/**
* Use 'oss' as the model when llmType is openai and an OSS model is in use;
* otherwise default to the given model.
* If no model is given, leave it undefined and the resolveProviderAndModel logic will determine the model from the connector.
* @param llmType
* @param isOssModel
* @param model
*/
export const getModelOrOss = (
llmType?: string,
isOssModel?: boolean,
model?: string
): string | undefined => (llmType === 'openai' && isOssModel ? 'oss' : model);
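A few hypothetical calls showing the intended behavior:

getModelOrOss('openai', true, 'gpt-4o'); // => 'oss' (an OSS model is in use)
getModelOrOss('openai', false, 'gpt-4o'); // => 'gpt-4o'
getModelOrOss('bedrock', true); // => undefined; resolved later from the connector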

View file

@ -0,0 +1,9 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export { getPrompt, getPromptsByGroupId } from './get_prompt';
export { promptDictionary } from './local_prompt_object';

View file

@ -0,0 +1,157 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Prompt } from './types';
import {
ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_INSIGHTS,
ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
ATTACK_DISCOVERY_GENERATION_TITLE,
ATTACK_DISCOVERY_CONTINUE,
ATTACK_DISCOVERY_DEFAULT,
ATTACK_DISCOVERY_REFINE,
BEDROCK_SYSTEM_PROMPT,
DEFAULT_SYSTEM_PROMPT,
GEMINI_SYSTEM_PROMPT,
GEMINI_USER_PROMPT,
STRUCTURED_SYSTEM_PROMPT,
} from './prompts';
export const promptGroupId = {
attackDiscovery: 'attackDiscovery',
aiAssistant: 'aiAssistant',
};
export const promptDictionary = {
systemPrompt: `systemPrompt`,
userPrompt: `userPrompt`,
attackDiscoveryDefault: `default`,
attackDiscoveryRefine: `refine`,
attackDiscoveryContinue: `continue`,
attackDiscoveryDetailsMarkdown: `detailsMarkdown`,
attackDiscoveryEntitySummaryMarkdown: `entitySummaryMarkdown`,
attackDiscoveryMitreAttackTactics: `mitreAttackTactics`,
attackDiscoverySummaryMarkdown: `summaryMarkdown`,
attackDiscoveryGenerationTitle: `generationTitle`,
attackDiscoveryGenerationInsights: `generationInsights`,
};
export const localPrompts: Prompt[] = [
{
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
prompt: {
default: DEFAULT_SYSTEM_PROMPT,
},
},
{
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
prompt: {
default: DEFAULT_SYSTEM_PROMPT,
},
},
{
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'bedrock',
prompt: {
default: BEDROCK_SYSTEM_PROMPT,
},
},
{
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'gemini',
prompt: {
default: GEMINI_SYSTEM_PROMPT,
},
},
{
promptId: promptDictionary.systemPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'openai',
model: 'oss',
prompt: {
default: STRUCTURED_SYSTEM_PROMPT,
},
},
{
promptId: promptDictionary.userPrompt,
promptGroupId: promptGroupId.aiAssistant,
provider: 'gemini',
prompt: {
default: GEMINI_USER_PROMPT,
},
},
{
promptId: promptDictionary.attackDiscoveryDefault,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_DEFAULT,
},
},
{
promptId: promptDictionary.attackDiscoveryRefine,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_REFINE,
},
},
{
promptId: promptDictionary.attackDiscoveryContinue,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_CONTINUE,
},
},
{
promptId: promptDictionary.attackDiscoveryDetailsMarkdown,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN,
},
},
{
promptId: promptDictionary.attackDiscoveryEntitySummaryMarkdown,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN,
},
},
{
promptId: promptDictionary.attackDiscoveryMitreAttackTactics,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS,
},
},
{
promptId: promptDictionary.attackDiscoverySummaryMarkdown,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN,
},
},
{
promptId: promptDictionary.attackDiscoveryGenerationTitle,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_TITLE,
},
},
{
promptId: promptDictionary.attackDiscoveryGenerationInsights,
promptGroupId: promptGroupId.attackDiscovery,
prompt: {
default: ATTACK_DISCOVERY_GENERATION_INSIGHTS,
},
},
];

View file

@ -0,0 +1,128 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export const KNOWLEDGE_HISTORY =
'If available, use the Knowledge History provided to try and answer the question. If not provided, you can try and query for additional knowledge via the KnowledgeBaseRetrievalTool.';
export const DEFAULT_SYSTEM_PROMPT = `You are a security analyst and expert in resolving security incidents. Your role is to assist by answering questions about Elastic Security. Do not answer questions unrelated to Elastic Security. ${KNOWLEDGE_HISTORY}`;
// system prompt from @afirstenberg
const BASE_GEMINI_PROMPT =
'You are an assistant that is an expert at using tools and Elastic Security, doing your best to use these tools to answer questions or follow instructions. It is very important to use tools to answer the question or follow the instructions rather than coming up with your own answer. Tool calls are good. Sometimes you may need to make several tool calls to accomplish the task or get an answer to the question that was asked. Use as many tool calls as necessary.';
const KB_CATCH =
'If the knowledge base tool gives empty results, do your best to answer the question from the perspective of an expert security analyst.';
export const GEMINI_SYSTEM_PROMPT = `${BASE_GEMINI_PROMPT} ${KB_CATCH}`;
export const BEDROCK_SYSTEM_PROMPT = `Use tools as often as possible, as they have access to the latest data and syntax. Always return value from NaturalLanguageESQLTool as is. Never return <thinking> tags in the response, but make sure to include <result> tags content in the response. Do not reflect on the quality of the returned search results in your response.`;
export const GEMINI_USER_PROMPT = `Now, always using the tools at your disposal, step by step, come up with a response to this request:\n\n`;
export const STRUCTURED_SYSTEM_PROMPT = `Respond to the human as helpfully and accurately as possible. ${KNOWLEDGE_HISTORY} You have access to the following tools:
{tools}
The tool action_input should ALWAYS follow the tool JSON schema args.
Valid "action" values: "Final Answer" or {tool_names}
Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input strictly adhering to the tool JSON schema args).
Provide only ONE action per $JSON_BLOB, as shown:
\`\`\`
{{
"action": $TOOL_NAME,
"action_input": $TOOL_INPUT
}}
\`\`\`
Follow this format:
Question: input question to answer
Thought: consider previous and subsequent steps
Action:
\`\`\`
$JSON_BLOB
\`\`\`
Observation: action result
... (repeat Thought/Action/Observation N times)
Thought: I know what to respond
Action:
\`\`\`
{{
"action": "Final Answer",
"action_input": "Final response to human"}}
Begin! Reminder to ALWAYS respond with a valid json blob of a single action with no additional output. When using tools, ALWAYS input the expected JSON schema args. Your answer will be parsed as JSON, so never use double quotes within the output and instead use backticks. Single quotes may be used, such as apostrophes. Response format is Action:\`\`\`$JSON_BLOB\`\`\`then Observation`;
export const ATTACK_DISCOVERY_DEFAULT =
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).";
export const ATTACK_DISCOVERY_REFINE = `You previously generated the following insights, but sometimes they represent the same attack.
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`;
export const ATTACK_DISCOVERY_CONTINUE = `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
1) it MUST conform to the schema above, because it will be checked against the JSON schema
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
3) it MUST NOT repeat any of the previous output, because that would prevent partial results from being combined
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
`;
const SYNTAX = '{{ field.name fieldValue1 fieldValue2 fieldValueN }}';
const GOOD_SYNTAX_EXAMPLES =
'Examples of CORRECT syntax (includes field names and values): {{ host.name hostNameValue }} {{ user.name userNameValue }} {{ source.ip sourceIpValue }}';
const BAD_SYNTAX_EXAMPLES =
'Examples of INCORRECT syntax (bad, because the field names are not included): {{ hostNameValue }} {{ userNameValue }} {{ sourceIpValue }}';
const RECONNAISSANCE = 'Reconnaissance';
const INITIAL_ACCESS = 'Initial Access';
const EXECUTION = 'Execution';
const PERSISTENCE = 'Persistence';
const PRIVILEGE_ESCALATION = 'Privilege Escalation';
const DISCOVERY = 'Discovery';
const LATERAL_MOVEMENT = 'Lateral Movement';
const COMMAND_AND_CONTROL = 'Command and Control';
const EXFILTRATION = 'Exfiltration';
const MITRE_ATTACK_TACTICS = [
RECONNAISSANCE,
INITIAL_ACCESS,
EXECUTION,
PERSISTENCE,
PRIVILEGE_ESCALATION,
DISCOVERY,
LATERAL_MOVEMENT,
COMMAND_AND_CONTROL,
EXFILTRATION,
] as const;
export const ATTACK_DISCOVERY_GENERATION_DETAILS_MARKDOWN = `A detailed insight with markdown, where each markdown bullet contains a description of what happened that reads like a story of the attack as it played out and always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`;
export const ATTACK_DISCOVERY_GENERATION_ENTITY_SUMMARY_MARKDOWN = `A short (no more than a sentence) summary of the insight featuring only the host.name and user.name fields (when they are applicable), using the same ${SYNTAX} syntax`;
export const ATTACK_DISCOVERY_GENERATION_MITRE_ATTACK_TACTICS = `An array of MITRE ATT&CK tactics for the insight, using one of the following values: ${MITRE_ATTACK_TACTICS.join(
','
)}`;
export const ATTACK_DISCOVERY_GENERATION_SUMMARY_MARKDOWN = `A markdown summary of the insight, using the same ${SYNTAX} syntax`;
export const ATTACK_DISCOVERY_GENERATION_TITLE =
'A short, no more than 7 words, title for the insight, NOT formatted with special syntax or markdown. This must be as brief as possible.';
export const ATTACK_DISCOVERY_GENERATION_INSIGHTS = `Insights with markdown that always uses special ${SYNTAX} syntax for field names and values from the source data. ${GOOD_SYNTAX_EXAMPLES} ${BAD_SYNTAX_EXAMPLES}`;
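
The `{{ field.name fieldValue }}` syntax above pairs each field name with its values so the generated markdown can later be rendered as interactive entities. A minimal sketch of how such tokens could be expanded into plain text, assuming a simple regex-based renderer (illustrative only, not the actual Kibana implementation):

// Hypothetical helper (not the actual Kibana renderer): expands
// "{{ field.name value1 value2 }}" tokens into "field.name: value1, value2".
const expandSyntaxTokens = (markdown: string): string =>
  markdown.replace(
    /\{\{\s*([\w.@]+)((?:\s+[^\s}]+)+)\s*\}\}/g,
    (_match, field: string, values: string) =>
      `${field}: ${values.trim().split(/\s+/).join(', ')}`
  );

// expandSyntaxTokens('{{ host.name web-01 }} accessed {{ user.name admin }}')
// => 'host.name: web-01 accessed user.name: admin'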

View file

@ -0,0 +1,52 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { SavedObjectsType } from '@kbn/core/server';
import { promptSavedObjectType } from '../../../common/constants';
export const promptSavedObjectMappings: SavedObjectsType['mappings'] = {
dynamic: false,
properties: {
description: {
type: 'text',
},
promptId: {
// represents a unique prompt
type: 'keyword',
},
promptGroupId: {
// represents a unique group of prompts
type: 'keyword',
},
provider: {
type: 'keyword',
},
model: {
type: 'keyword',
},
prompt: {
properties: {
// English is the default
default: {
type: 'text',
},
// optionally, add an ISO 639 two-letter language code to support more translations
},
},
},
};
export const promptType: SavedObjectsType = {
name: promptSavedObjectType,
hidden: false,
management: {
importableAndExportable: true,
visibleInManagement: false,
},
namespaceType: 'agnostic',
mappings: promptSavedObjectMappings,
};
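
For illustration, writing one of these saved objects with a saved objects client could look like the sketch below; the attribute values are hypothetical (in practice the prompts are installed as assets by the Security AI prompts integration package):

import type { SavedObjectsClientContract } from '@kbn/core/server';
import { promptSavedObjectType } from '../../../common/constants';

// Sketch only: creates a single prompt saved object with illustrative values.
export const createExamplePrompt = async (soClient: SavedObjectsClientContract) =>
  soClient.create(promptSavedObjectType, {
    promptId: 'systemPrompt', // hypothetical id: unique prompt within its group
    promptGroupId: 'aiAssistant', // hypothetical id: the group this prompt belongs to
    provider: 'bedrock', // optional: scope the prompt to one LLM provider
    prompt: {
      default: 'You are a helpful Elastic Security assistant.', // English default text
    },
  });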

View file

@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
export interface Prompt {
promptId: string;
promptGroupId: string;
prompt: {
default: string;
};
provider?: string;
model?: string;
description?: string;
}
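
A value conforming to this interface might look like the following (the ids and text are placeholders, not shipped prompts):

const examplePrompt: Prompt = {
  promptId: 'attackDiscoveryDefault', // hypothetical
  promptGroupId: 'attackDiscovery', // hypothetical
  prompt: {
    default: 'You are a cyber security analyst...',
  },
  provider: 'gemini', // optional: only applies when the connector uses Gemini
};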

View file

@ -10,6 +10,7 @@ import { PluginInitializerContext, CoreStart, Plugin, Logger } from '@kbn/core/s
import { AssistantFeatures } from '@kbn/elastic-assistant-common';
import { ReplaySubject, type Subject } from 'rxjs';
import { MlPluginSetup } from '@kbn/ml-plugin/server';
import { initSavedObjects } from './saved_objects';
import { events } from './lib/telemetry/event_based_telemetry';
import {
AssistantTool,
@ -55,6 +56,8 @@ export class ElasticAssistantPlugin
) {
this.logger.debug('elasticAssistant: Setup');
initSavedObjects(core.savedObjects);
this.assistantService = new AIAssistantService({
logger: this.logger.get('service'),
ml: plugins.ml,

View file

@ -7,7 +7,7 @@
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import { Logger } from '@kbn/core/server';
import { Logger, SavedObjectsClientContract } from '@kbn/core/server';
import { ApiConfig, AttackDiscovery, Replacements } from '@kbn/elastic-assistant-common';
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
import { ActionsClientLlm } from '@kbn/langchain/server';
@ -23,6 +23,7 @@ import {
import { GraphState } from '../../../../../lib/attack_discovery/graphs/default_attack_discovery_graph/types';
import { throwIfErrorCountsExceeded } from '../throw_if_error_counts_exceeded';
import { getLlmType } from '../../../../utils';
import { getAttackDiscoveryPrompts } from '../../../../../lib/attack_discovery/graphs/default_attack_discovery_graph/nodes/helpers/prompts';
export const invokeAttackDiscoveryGraph = async ({
actionsClient,
@ -38,6 +39,7 @@ export const invokeAttackDiscoveryGraph = async ({
latestReplacements,
logger,
onNewReplacements,
savedObjectsClient,
size,
start,
}: {
@ -54,6 +56,7 @@ export const invokeAttackDiscoveryGraph = async ({
latestReplacements: Replacements;
logger: Logger;
onNewReplacements: (newReplacements: Replacements) => void;
savedObjectsClient: SavedObjectsClientContract;
start?: string;
size: number;
}): Promise<{
@ -89,6 +92,15 @@ export const invokeAttackDiscoveryGraph = async ({
throw new Error('LLM is required for attack discoveries');
}
const attackDiscoveryPrompts = await getAttackDiscoveryPrompts({
actionsClient,
connectorId: apiConfig.connectorId,
// if OSS models need a different prompt in the future, pass that model here
model,
provider: llmType,
savedObjectsClient,
});
const graph = getDefaultAttackDiscoveryGraph({
alertsIndexPattern,
anonymizationFields,
@ -98,6 +110,7 @@ export const invokeAttackDiscoveryGraph = async ({
llm,
logger,
onNewReplacements,
prompts: attackDiscoveryPrompts,
replacements: latestReplacements,
size,
start,
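
The attackDiscoveryPrompts passed into the graph above come from `security-ai-prompt` saved objects when the integration has installed them, with locally defined defaults as the fallback. A rough sketch of that find-then-fallback lookup, assuming a KQL filter over the mapped keyword fields (the actual helper under lib/prompt may differ):

import type { SavedObjectsClientContract } from '@kbn/core/server';

// Sketch: prefer an installed prompt saved object, fall back to the bundled default.
const resolvePrompt = async ({
  savedObjectsClient,
  promptId,
  promptGroupId,
  localDefault,
}: {
  savedObjectsClient: SavedObjectsClientContract;
  promptId: string;
  promptGroupId: string;
  localDefault: string;
}): Promise<string> => {
  const { saved_objects: hits } = await savedObjectsClient.find<{
    prompt: { default: string };
  }>({
    type: 'security-ai-prompt',
    filter:
      `security-ai-prompt.attributes.promptId: "${promptId}" and ` +
      `security-ai-prompt.attributes.promptGroupId: "${promptGroupId}"`,
  });
  return hits[0]?.attributes.prompt.default ?? localDefault;
};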

View file

@ -66,6 +66,7 @@ export const postAttackDiscoveryRoute = (
const assistantContext = await context.elasticAssistant;
const logger: Logger = assistantContext.logger;
const telemetry = assistantContext.telemetry;
const savedObjectsClient = assistantContext.savedObjectsClient;
try {
// get the actions plugin start contract from the request context:
@ -144,6 +145,7 @@ export const postAttackDiscoveryRoute = (
latestReplacements,
logger,
onNewReplacements,
savedObjectsClient,
size,
start,
})

View file

@ -106,6 +106,7 @@ export const chatCompleteRoute = (
const connector = connectors.length > 0 ? connectors[0] : undefined;
actionTypeId = connector?.actionTypeId ?? '.gen-ai';
const isOssModel = isOpenSourceModel(connector);
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
// replacements
const anonymizationFieldsRes =
@ -221,6 +222,7 @@ export const chatCompleteRoute = (
response,
telemetry,
responseLanguage: request.body.responseLanguage,
savedObjectsClient,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});
} catch (err) {

View file

@ -30,6 +30,14 @@ import {
createToolCallingAgent,
} from 'langchain/agents';
import { omit } from 'lodash/fp';
import { promptGroupId } from '../../lib/prompt/local_prompt_object';
import { getModelOrOss } from '../../lib/prompt/helpers';
import { getAttackDiscoveryPrompts } from '../../lib/attack_discovery/graphs/default_attack_discovery_graph/nodes/helpers/prompts';
import {
formatPrompt,
formatPromptStructured,
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
import { getPrompt, promptDictionary } from '../../lib/prompt';
import { buildResponse } from '../../lib/build_response';
import { AssistantDataClients } from '../../lib/langchain/executors/types';
import { AssistantToolParams, ElasticAssistantRequestHandlerContext, GetElser } from '../../types';
@ -42,12 +50,6 @@ import {
DefaultAssistantGraph,
getDefaultAssistantGraph,
} from '../../lib/langchain/graphs/default_assistant_graph/graph';
import {
bedrockToolCallingAgentPrompt,
geminiToolCallingAgentPrompt,
openAIFunctionAgentPrompt,
structuredChatAgentPrompt,
} from '../../lib/langchain/graphs/default_assistant_graph/prompts';
import { getLlmClass, getLlmType, isOpenSourceModel } from '../utils';
import { getGraphsFromNames } from './get_graphs_from_names';
@ -95,6 +97,7 @@ export const postEvaluateRoute = (
const actions = ctx.elasticAssistant.actions;
const logger = assistantContext.logger.get('evaluate');
const abortSignal = getRequestAbortedSignal(request.events.aborted$);
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
// Perform license, authenticated user and evaluation FF checks
const checkResponse = performChecks({
@ -176,6 +179,20 @@ export const postEvaluateRoute = (
ids: connectorIds,
throwIfSystemAction: false,
});
const connectorsWithPrompts = await Promise.all(
connectors.map(async (connector) => {
const prompts = await getAttackDiscoveryPrompts({
actionsClient,
connectorId: connector.id,
connector,
savedObjectsClient,
});
return {
...connector,
prompts,
};
})
);
// Fetch any tools registered to the security assistant
const assistantTools = assistantContext.getRegisteredTools(DEFAULT_PLUGIN_NAME);
@ -190,7 +207,7 @@ export const postEvaluateRoute = (
actionsClient,
alertsIndexPattern,
attackDiscoveryGraphs,
connectors,
connectors: connectorsWithPrompts,
connectorTimeout: CONNECTOR_TIMEOUT,
datasetName,
esClient,
@ -217,6 +234,7 @@ export const postEvaluateRoute = (
graph: DefaultAssistantGraph;
llmType: string | undefined;
isOssModel: boolean | undefined;
connectorId: string;
}> = await Promise.all(
connectors.map(async (connector) => {
const llmType = getLlmType(connector.actionTypeId);
@ -293,31 +311,40 @@ export const postEvaluateRoute = (
(tool) => tool.getTool(assistantToolParams) ?? []
);
const defaultSystemPrompt = await getPrompt({
actionsClient,
connector,
connectorId: connector.id,
model: getModelOrOss(llmType, isOssModel),
promptGroupId: promptGroupId.aiAssistant,
promptId: promptDictionary.systemPrompt,
provider: llmType,
savedObjectsClient,
});
const agentRunnable = isOpenAI
? await createOpenAIFunctionsAgent({
llm,
tools,
prompt: openAIFunctionAgentPrompt,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: llmType && ['bedrock', 'gemini'].includes(llmType)
? createToolCallingAgent({
llm,
tools,
prompt:
llmType === 'bedrock'
? bedrockToolCallingAgentPrompt
: geminiToolCallingAgentPrompt,
prompt: formatPrompt(defaultSystemPrompt),
streamRunnable: false,
})
: await createStructuredChatAgent({
llm,
tools,
prompt: structuredChatAgentPrompt,
prompt: formatPromptStructured(defaultSystemPrompt),
streamRunnable: false,
});
return {
connectorId: connector.id,
name: `${runName} - ${connector.name}`,
llmType,
isOssModel,
@ -326,6 +353,8 @@ export const postEvaluateRoute = (
dataClients,
createLlmInstance,
logger,
actionsClient,
savedObjectsClient,
tools,
replacements: {},
}),
@ -334,7 +363,7 @@ export const postEvaluateRoute = (
);
// Run an evaluation for each graph so they show up separately (resulting in each dataset run grouped by connector)
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel }) => {
await asyncForEach(graphs, async ({ name, graph, llmType, isOssModel, connectorId }) => {
// Wrapper function for invoking the graph (to parse different input/output formats)
const predict = async (input: { input: string }) => {
logger.debug(`input:\n ${JSON.stringify(input, null, 2)}`);
@ -342,6 +371,7 @@ export const postEvaluateRoute = (
const r = await graph.invoke(
{
input: input.input,
connectorId,
conversationId: undefined,
responseLanguage: 'English',
llmType,
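
The formatPrompt and formatPromptStructured helpers imported above wrap the resolved system prompt into the chat prompt template the agents expect. A sketch of what such a wrapper could look like with @langchain/core, assuming the conventional agent placeholder slots (the real implementations live in default_assistant_graph/prompts and may differ):

import { ChatPromptTemplate } from '@langchain/core/prompts';

// Sketch: turn a plain system-prompt string into an agent-ready template.
const formatPromptSketch = (systemPrompt: string) =>
  ChatPromptTemplate.fromMessages([
    ['system', systemPrompt],
    ['placeholder', '{chat_history}'],
    ['human', '{input}'],
    ['placeholder', '{agent_scratchpad}'],
  ]);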

View file

@ -12,6 +12,7 @@ import {
KibanaRequest,
KibanaResponseFactory,
Logger,
SavedObjectsClientContract,
} from '@kbn/core/server';
import { StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
@ -235,6 +236,7 @@ export interface LangChainExecuteParams {
getElser: GetElser;
response: KibanaResponseFactory;
responseLanguage?: string;
savedObjectsClient: SavedObjectsClientContract;
systemPrompt?: string;
}
export const langChainExecute = async ({
@ -258,6 +260,7 @@ export const langChainExecute = async ({
response,
responseLanguage,
isStream = true,
savedObjectsClient,
systemPrompt,
}: LangChainExecuteParams) => {
// Fetch any tools registered by the request's originating plugin
@ -316,6 +319,7 @@ export const langChainExecute = async ({
request,
replacements,
responseLanguage,
savedObjectsClient,
size: request.body.size,
systemPrompt,
telemetry,

View file

@ -99,6 +99,7 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = ctx.elasticAssistant.actions;
const inference = ctx.elasticAssistant.inference;
const savedObjectsClient = ctx.elasticAssistant.savedObjectsClient;
const productDocsAvailable =
(await ctx.elasticAssistant.llmTasks.retrieveDocumentationAvailable()) ?? false;
const actionsClient = await actions.getActionsClientWithRequest(request);
@ -153,6 +154,7 @@ export const postActionsConnectorExecuteRoute = (
request,
response,
telemetry,
savedObjectsClient,
systemPrompt,
...(productDocsAvailable ? { llmTasks: ctx.elasticAssistant.llmTasks } : {}),
});

View file

@ -80,7 +80,7 @@ export class RequestContextFactory implements IRequestContextFactory {
},
llmTasks: startPlugins.llmTasks,
inference: startPlugins.inference,
savedObjectsClient: coreStart.savedObjects.getScopedClient(request),
telemetry: core.analytics,
// Note: modelIdOverride is used here to enable setting up the KB using a different ELSER model, which

View file

@ -0,0 +1,18 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import type { CoreSetup } from '@kbn/core/server';
import { promptType } from './lib/prompt/saved_object_mappings';
export const initSavedObjects = (savedObjects: CoreSetup['savedObjects']) => {
try {
savedObjects.registerType(promptType);
} catch (e) {
// fall back to reasonable defaults when saved object registration fails;
// do not block the plugin from starting
}
};

View file

@ -9,23 +9,24 @@ import type {
PluginSetupContract as ActionsPluginSetup,
PluginStartContract as ActionsPluginStart,
} from '@kbn/actions-plugin/server';
import type {
import {
AuthenticatedUser,
CoreRequestHandlerContext,
CoreSetup,
AnalyticsServiceSetup,
CustomRequestHandlerContext,
ElasticsearchClient,
IRouter,
KibanaRequest,
Logger,
AuditLogger,
SavedObjectsClientContract,
} from '@kbn/core/server';
import type { LlmTasksPluginStart } from '@kbn/llm-tasks-plugin/server';
import { type MlPluginSetup } from '@kbn/ml-plugin/server';
import { DynamicStructuredTool, Tool } from '@langchain/core/tools';
import { SpacesPluginSetup, SpacesPluginStart } from '@kbn/spaces-plugin/server';
import { TaskManagerSetupContract } from '@kbn/task-manager-plugin/server';
import { ElasticsearchClient } from '@kbn/core/server';
import {
AttackDiscoveryPostRequestBody,
DefendInsightsPostRequestBody,
@ -140,6 +141,7 @@ export interface ElasticAssistantApiRequestHandlerContext {
getAIAssistantAnonymizationFieldsDataClient: () => Promise<AIAssistantDataClient | null>;
llmTasks: LlmTasksPluginStart;
inference: InferenceServerStart;
savedObjectsClient: SavedObjectsClientContract;
telemetry: AnalyticsServiceSetup;
}
/**

View file

@ -53,6 +53,8 @@
"@kbn/core-analytics-server",
"@kbn/llm-tasks-plugin",
"@kbn/product-doc-base-plugin",
"@kbn/core-saved-objects-api-server-mocks",
"@kbn/inference-common"
],
"exclude": [
"target/**/*",

View file

@ -171,6 +171,7 @@ export const generateFleetPackageInfo = (): PackageInfo => {
ml_module: [],
security_rule: [],
tag: [],
security_ai_prompt: [],
osquery_pack_asset: [],
osquery_saved_query: [],
},