mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 01:38:56 -04:00
# Backport This will backport the following commits from `main` to `8.16`: - [[Security Solution] [Attack discovery] Additional Attack discovery tests (#199659)](https://github.com/elastic/kibana/pull/199659) <!--- Backport version: 8.9.8 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) <!--BACKPORT [{"author":{"name":"Andrew Macri","email":"andrew.macri@elastic.co"},"sourceCommit":{"committedDate":"2024-11-13T17:37:54Z","message":"[Security Solution] [Attack discovery] Additional Attack discovery tests (#199659)\n\n### [Security Solution] [Attack discovery] Additional Attack discovery tests\r\n\r\nThis PR adds additional unit test coverage to Attack discovery.","sha":"53d4580a8959a9e4b166df4e4a4cc83de61f7928","branchLabelMapping":{"^v9.0.0$":"main","^v8.17.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","Team: SecuritySolution","Team:Security Generative AI","backport:version","v8.17.0","v8.16.1"],"number":199659,"url":"https://github.com/elastic/kibana/pull/199659","mergeCommit":{"message":"[Security Solution] [Attack discovery] Additional Attack discovery tests (#199659)\n\n### [Security Solution] [Attack discovery] Additional Attack discovery tests\r\n\r\nThis PR adds additional unit test coverage to Attack discovery.","sha":"53d4580a8959a9e4b166df4e4a4cc83de61f7928"}},"sourceBranch":"main","suggestedTargetBranches":["8.16"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","labelRegex":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/199659","number":199659,"mergeCommit":{"message":"[Security Solution] [Attack discovery] Additional Attack discovery tests (#199659)\n\n### [Security Solution] [Attack discovery] Additional Attack discovery tests\r\n\r\nThis PR adds additional unit test coverage to Attack discovery.","sha":"53d4580a8959a9e4b166df4e4a4cc83de61f7928"}},{"branch":"8.x","label":"v8.17.0","labelRegex":"^v8.17.0$","isSourceBranch":false,"url":"https://github.com/elastic/kibana/pull/200061","number":200061,"state":"OPEN"},{"branch":"8.16","label":"v8.16.1","labelRegex":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"}]}] BACKPORT-->
This commit is contained in:
parent
ecfb386daa
commit
58dd0876a4
33 changed files with 2182 additions and 40 deletions
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
import React from 'react';
|
||||
|
||||
import { AlertsRange } from './alerts_range';
|
||||
import {
|
||||
MAX_LATEST_ALERTS,
|
||||
MIN_LATEST_ALERTS,
|
||||
} from '../assistant/settings/alerts_settings/alerts_settings';
|
||||
import { KnowledgeBaseConfig } from '../assistant/types';
|
||||
|
||||
const nonDefaultMin = MIN_LATEST_ALERTS + 5000;
|
||||
const nonDefaultMax = nonDefaultMin + 5000;
|
||||
|
||||
describe('AlertsRange', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('renders the expected default min alerts', () => {
|
||||
render(<AlertsRange value={200} />);
|
||||
|
||||
expect(screen.getByText(`${MIN_LATEST_ALERTS}`)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders the expected NON-default min alerts', () => {
|
||||
render(
|
||||
<AlertsRange maxAlerts={nonDefaultMax} minAlerts={nonDefaultMin} value={nonDefaultMin} />
|
||||
);
|
||||
|
||||
expect(screen.getByText(`${nonDefaultMin}`)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders the expected default max alerts', () => {
|
||||
render(<AlertsRange value={200} />);
|
||||
|
||||
expect(screen.getByText(`${MAX_LATEST_ALERTS}`)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('renders the expected NON-default max alerts', () => {
|
||||
render(
|
||||
<AlertsRange maxAlerts={nonDefaultMax} minAlerts={nonDefaultMin} value={nonDefaultMax} />
|
||||
);
|
||||
|
||||
expect(screen.getByText(`${nonDefaultMax}`)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onChange when the range value changes', () => {
|
||||
const mockOnChange = jest.fn();
|
||||
render(<AlertsRange onChange={mockOnChange} value={MIN_LATEST_ALERTS} />);
|
||||
|
||||
fireEvent.click(screen.getByText(`${MAX_LATEST_ALERTS}`));
|
||||
|
||||
expect(mockOnChange).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('calls setUpdatedKnowledgeBaseSettings with the expected arguments', () => {
|
||||
const mockSetUpdatedKnowledgeBaseSettings = jest.fn();
|
||||
const knowledgeBase: KnowledgeBaseConfig = { latestAlerts: 150 };
|
||||
|
||||
render(
|
||||
<AlertsRange
|
||||
knowledgeBase={knowledgeBase}
|
||||
setUpdatedKnowledgeBaseSettings={mockSetUpdatedKnowledgeBaseSettings}
|
||||
value={MIN_LATEST_ALERTS}
|
||||
/>
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByText(`${MAX_LATEST_ALERTS}`));
|
||||
|
||||
expect(mockSetUpdatedKnowledgeBaseSettings).toHaveBeenCalledWith({
|
||||
...knowledgeBase,
|
||||
latestAlerts: MAX_LATEST_ALERTS,
|
||||
});
|
||||
});
|
||||
|
||||
it('renders with the correct initial value', () => {
|
||||
render(<AlertsRange value={250} />);
|
||||
|
||||
expect(screen.getByTestId('alertsRange')).toHaveValue('250');
|
||||
});
|
||||
});
|
File diff suppressed because one or more lines are too long
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { AnonymizationFieldResponse } from '@kbn/elastic-assistant-common/impl/schemas/anonymization_fields/bulk_crud_anonymization_fields_route.gen';
|
||||
|
||||
export const getMockAnonymizationFieldResponse = (): AnonymizationFieldResponse[] => [
|
||||
{
|
||||
id: '6UDO45IBoEQSo_rIK1EW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: '_id',
|
||||
allowed: true,
|
||||
anonymized: false,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
{
|
||||
id: '6kDO45IBoEQSo_rIK1EW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: '@timestamp',
|
||||
allowed: true,
|
||||
anonymized: false,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
{
|
||||
id: '60DO45IBoEQSo_rIK1EW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: 'cloud.availability_zone',
|
||||
allowed: true,
|
||||
anonymized: false,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
{
|
||||
id: '_EDO45IBoEQSo_rIK1EW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: 'host.name',
|
||||
allowed: true,
|
||||
anonymized: true,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
{
|
||||
id: 'SkDO45IBoEQSo_rIK1IW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: 'user.name',
|
||||
allowed: true,
|
||||
anonymized: true,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
{
|
||||
id: 'TUDO45IBoEQSo_rIK1IW',
|
||||
timestamp: '2024-10-31T18:19:52.468Z',
|
||||
field: 'user.target.name',
|
||||
allowed: true,
|
||||
anonymized: true,
|
||||
createdAt: '2024-10-31T18:19:52.468Z',
|
||||
namespace: 'default',
|
||||
},
|
||||
];
|
|
@ -12,7 +12,7 @@ describe('getAlertsContextPrompt', () => {
|
|||
it('generates the correct prompt', () => {
|
||||
const anonymizedAlerts = ['Alert 1', 'Alert 2', 'Alert 3'];
|
||||
|
||||
const expected = `You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).
|
||||
const expected = `${getDefaultAttackDiscoveryPrompt()}
|
||||
|
||||
Use context from the following alerts to provide insights:
|
||||
|
||||
|
|
|
@ -5,8 +5,8 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import { FakeLLM } from '@langchain/core/utils/testing';
|
||||
|
||||
import { getGenerateNode } from '.';
|
||||
|
@ -16,7 +16,15 @@ import {
|
|||
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getAnonymizedAlertsFromState } from './helpers/get_anonymized_alerts_from_state';
|
||||
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
|
||||
import { GraphState } from '../../types';
|
||||
import {
|
||||
getParsedAttackDiscoveriesMock,
|
||||
getRawAttackDiscoveriesMock,
|
||||
} from '../../../../../../__mocks__/raw_attack_discoveries';
|
||||
|
||||
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
|
||||
|
||||
jest.mock('../helpers/get_chain_with_format_instructions', () => {
|
||||
const mockInvoke = jest.fn().mockResolvedValue('');
|
||||
|
@ -27,19 +35,21 @@ jest.mock('../helpers/get_chain_with_format_instructions', () => {
|
|||
invoke: mockInvoke,
|
||||
},
|
||||
formatInstructions: ['mock format instructions'],
|
||||
llmType: 'fake',
|
||||
llmType: 'openai',
|
||||
mockInvoke, // <-- added for testing
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockLogger = loggerMock.create();
|
||||
const mockLogger = {
|
||||
debug: (x: Function) => x(),
|
||||
} as unknown as Logger;
|
||||
|
||||
let mockLlm: ActionsClientLlm;
|
||||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt:
|
||||
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds).",
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
anonymizedAlerts: [...mockAnonymizedAlerts],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
|
@ -51,8 +61,7 @@ const initialGraphState: GraphState = {
|
|||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt:
|
||||
'You previously generated the following insights, but sometimes they represent the same attack.\n\nCombine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:',
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
replacements: {
|
||||
...mockAnonymizedAlertsReplacements,
|
||||
},
|
||||
|
@ -63,11 +72,18 @@ describe('getGenerateNode', () => {
|
|||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
jest.useFakeTimers();
|
||||
jest.setSystemTime(new Date(attackDiscoveryTimestamp));
|
||||
|
||||
mockLlm = new FakeLLM({
|
||||
response: JSON.stringify({}, null, 2),
|
||||
response: '',
|
||||
}) as unknown as ActionsClientLlm;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('returns a function', () => {
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlm,
|
||||
|
@ -77,9 +93,8 @@ describe('getGenerateNode', () => {
|
|||
expect(typeof generateNode).toBe('function');
|
||||
});
|
||||
|
||||
it('invokes the chain with the alerts from state and format instructions', async () => {
|
||||
// @ts-expect-error
|
||||
const { mockInvoke } = getChainWithFormatInstructions(mockLlm);
|
||||
it('invokes the chain with the expected alerts from state and formatting instructions', async () => {
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlm,
|
||||
|
@ -100,4 +115,214 @@ ${getAnonymizedAlertsFromState(initialGraphState).join('\n\n')}
|
|||
`,
|
||||
});
|
||||
});
|
||||
|
||||
it('removes the surrounding json from the response', async () => {
|
||||
const response =
|
||||
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const state = await generateNode(initialGraphState);
|
||||
|
||||
expect(state).toEqual({
|
||||
...initialGraphState,
|
||||
combinedGenerations: '{"key": "value"}',
|
||||
errors: [
|
||||
'generate node is unable to parse (fake) response from attempt 0; (this may be an incomplete response from the model): [\n {\n "code": "invalid_type",\n "expected": "array",\n "received": "undefined",\n "path": [\n "insights"\n ],\n "message": "Required"\n }\n]',
|
||||
],
|
||||
generationAttempts: 1,
|
||||
generations: ['{"key": "value"}'],
|
||||
});
|
||||
});
|
||||
|
||||
it('handles hallucinations', async () => {
|
||||
const hallucinatedResponse =
|
||||
'tactics like **Credential Access**, **Command and Control**, and **Persistence**.",\n "entitySummaryMarkdown": "Malware detected on host **{{ host.name hostNameValue }}**';
|
||||
|
||||
const mockLlmWithHallucination = new FakeLLM({
|
||||
response: hallucinatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(hallucinatedResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithHallucination,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedGenerations: '{"key": "value"}',
|
||||
generationAttempts: 1,
|
||||
generations: ['{"key": "value"}'],
|
||||
};
|
||||
|
||||
const state = await generateNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedGenerations: '', // <-- reset
|
||||
generationAttempts: 2, // <-- incremented
|
||||
generations: [], // <-- reset
|
||||
hallucinationFailures: 1, // <-- incremented
|
||||
});
|
||||
});
|
||||
|
||||
it('discards previous generations and starts over when the maxRepeatedGenerations limit is reached', async () => {
|
||||
const repeatedResponse = 'gen1';
|
||||
|
||||
const mockLlmWithRepeatedGenerations = new FakeLLM({
|
||||
response: repeatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(repeatedResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedGenerations: 'gen1gen1',
|
||||
generationAttempts: 2,
|
||||
generations: ['gen1', 'gen1'],
|
||||
};
|
||||
|
||||
const state = await generateNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedGenerations: '',
|
||||
generationAttempts: 3, // <-- incremented
|
||||
generations: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('combines the response with the previous generations', async () => {
|
||||
const response = 'gen1';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedGenerations: 'gen0',
|
||||
generationAttempts: 1,
|
||||
generations: ['gen0'],
|
||||
};
|
||||
|
||||
const state = await generateNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedGenerations: 'gen0gen1',
|
||||
errors: [
|
||||
'generate node is unable to parse (fake) response from attempt 1; (this may be an incomplete response from the model): SyntaxError: Unexpected token \'g\', "gen0gen1" is not valid JSON',
|
||||
],
|
||||
generationAttempts: 2,
|
||||
generations: ['gen0', 'gen1'],
|
||||
});
|
||||
});
|
||||
|
||||
it('returns unrefined results when combined responses pass validation', async () => {
|
||||
// split the response into two parts to simulate a valid response
|
||||
const splitIndex = 100; // arbitrary index
|
||||
const firstResponse = getRawAttackDiscoveriesMock().slice(0, splitIndex);
|
||||
const secondResponse = getRawAttackDiscoveriesMock().slice(splitIndex);
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedGenerations: firstResponse,
|
||||
generationAttempts: 1,
|
||||
generations: [firstResponse],
|
||||
};
|
||||
|
||||
const state = await generateNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
attackDiscoveries: null,
|
||||
combinedGenerations: firstResponse.concat(secondResponse),
|
||||
errors: [],
|
||||
generationAttempts: 2,
|
||||
generations: [firstResponse, secondResponse],
|
||||
unrefinedResults: getParsedAttackDiscoveriesMock(attackDiscoveryTimestamp), // <-- generated from the combined response
|
||||
});
|
||||
});
|
||||
|
||||
it('skips the refinements step if the max number of retries has already been reached', async () => {
|
||||
// split the response into two parts to simulate a valid response
|
||||
const splitIndex = 100; // arbitrary index
|
||||
const firstResponse = getRawAttackDiscoveriesMock().slice(0, splitIndex);
|
||||
const secondResponse = getRawAttackDiscoveriesMock().slice(splitIndex);
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
||||
const generateNode = getGenerateNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedGenerations: firstResponse,
|
||||
generationAttempts: 9,
|
||||
generations: [firstResponse],
|
||||
};
|
||||
|
||||
const state = await generateNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
attackDiscoveries: getParsedAttackDiscoveriesMock(attackDiscoveryTimestamp), // <-- skip the refinement step
|
||||
combinedGenerations: firstResponse.concat(secondResponse),
|
||||
errors: [],
|
||||
generationAttempts: 10,
|
||||
generations: [firstResponse, secondResponse],
|
||||
unrefinedResults: getParsedAttackDiscoveriesMock(attackDiscoveryTimestamp), // <-- generated from the combined response
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
@ -58,10 +58,10 @@ export const getGenerateNode = ({
|
|||
() => `generate node is invoking the chain (${llmType}), attempt ${generationAttempts}`
|
||||
);
|
||||
|
||||
const rawResponse = (await chain.invoke({
|
||||
const rawResponse = await chain.invoke({
|
||||
format_instructions: formatInstructions,
|
||||
query,
|
||||
})) as unknown as string;
|
||||
});
|
||||
|
||||
// LOCAL MUTATION:
|
||||
partialResponse = extractJson(rawResponse); // remove the surrounding ```json```
|
||||
|
@ -86,7 +86,7 @@ export const getGenerateNode = ({
|
|||
generationsAreRepeating({
|
||||
currentGeneration: partialResponse,
|
||||
previousGenerations: generations,
|
||||
sampleLastNGenerations: maxRepeatedGenerations,
|
||||
sampleLastNGenerations: maxRepeatedGenerations - 1,
|
||||
})
|
||||
) {
|
||||
logger?.debug(
|
||||
|
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { addTrailingBackticksIfNecessary } from '.';
|
||||
|
||||
describe('addTrailingBackticksIfNecessary', () => {
|
||||
it('adds trailing backticks when necessary', () => {
|
||||
const input = '```json\n{\n "key": "value"\n}';
|
||||
const expected = '```json\n{\n "key": "value"\n}\n```';
|
||||
const result = addTrailingBackticksIfNecessary(input);
|
||||
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
|
||||
it('does NOT add trailing backticks when they are already present', () => {
|
||||
const input = '```json\n{\n "key": "value"\n}\n```';
|
||||
const result = addTrailingBackticksIfNecessary(input);
|
||||
|
||||
expect(result).toEqual(input);
|
||||
});
|
||||
|
||||
it("does NOT add trailing backticks when there's no leading JSON wrapper", () => {
|
||||
const input = '{\n "key": "value"\n}';
|
||||
const result = addTrailingBackticksIfNecessary(input);
|
||||
|
||||
expect(result).toEqual(input);
|
||||
});
|
||||
|
||||
it('handles empty string input', () => {
|
||||
const input = '';
|
||||
const result = addTrailingBackticksIfNecessary(input);
|
||||
|
||||
expect(result).toEqual(input);
|
||||
});
|
||||
|
||||
it('handles input without a JSON wrapper, but with trailing backticks', () => {
|
||||
const input = '{\n "key": "value"\n}\n```';
|
||||
const result = addTrailingBackticksIfNecessary(input);
|
||||
|
||||
expect(result).toEqual(input);
|
||||
});
|
||||
});
|
|
@ -8,6 +8,24 @@
|
|||
import { extractJson } from '.';
|
||||
|
||||
describe('extractJson', () => {
|
||||
it('returns an empty string if input is undefined', () => {
|
||||
const input = undefined;
|
||||
|
||||
expect(extractJson(input)).toBe('');
|
||||
});
|
||||
|
||||
it('returns an empty string if input an array', () => {
|
||||
const input = ['some', 'array'];
|
||||
|
||||
expect(extractJson(input)).toBe('');
|
||||
});
|
||||
|
||||
it('returns an empty string if input is an object', () => {
|
||||
const input = {};
|
||||
|
||||
expect(extractJson(input)).toBe('');
|
||||
});
|
||||
|
||||
it('returns the JSON text surrounded by ```json and ``` with no whitespace or additional text', () => {
|
||||
const input = '```json{"key": "value"}```';
|
||||
|
||||
|
|
|
@ -5,7 +5,11 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
export const extractJson = (input: string): string => {
|
||||
export const extractJson = (input: unknown): string => {
|
||||
if (typeof input !== 'string') {
|
||||
return '';
|
||||
}
|
||||
|
||||
const regex = /```json\s*([\s\S]*?)(?:\s*```|$)/;
|
||||
const match = input.match(regex);
|
||||
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getCombined } from '.';
|
||||
|
||||
describe('getCombined', () => {
|
||||
it('combines two strings correctly', () => {
|
||||
const combinedGenerations = 'generation1';
|
||||
const partialResponse = 'response1';
|
||||
const expected = 'generation1response1';
|
||||
const result = getCombined({ combinedGenerations, partialResponse });
|
||||
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
|
||||
it('handles empty combinedGenerations', () => {
|
||||
const combinedGenerations = '';
|
||||
const partialResponse = 'response1';
|
||||
const expected = 'response1';
|
||||
const result = getCombined({ combinedGenerations, partialResponse });
|
||||
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
|
||||
it('handles an empty partialResponse', () => {
|
||||
const combinedGenerations = 'generation1';
|
||||
const partialResponse = '';
|
||||
const expected = 'generation1';
|
||||
const result = getCombined({ combinedGenerations, partialResponse });
|
||||
|
||||
expect(result).toEqual(expected);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,22 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getContinuePrompt } from '.';
|
||||
|
||||
describe('getContinuePrompt', () => {
|
||||
it('returns the expected prompt string', () => {
|
||||
const expectedPrompt = `Continue exactly where you left off in the JSON output below, generating only the additional JSON output when it's required to complete your work. The additional JSON output MUST ALWAYS follow these rules:
|
||||
1) it MUST conform to the schema above, because it will be checked against the JSON schema
|
||||
2) it MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds), because it will be parsed as JSON
|
||||
3) it MUST NOT repeat any the previous output, because that would prevent partial results from being combined
|
||||
4) it MUST NOT restart from the beginning, because that would prevent partial results from being combined
|
||||
5) it MUST NOT be prefixed or suffixed with additional text outside of the JSON, because that would prevent it from being combined and parsed as JSON:
|
||||
`;
|
||||
|
||||
expect(getContinuePrompt()).toBe(expectedPrompt);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,16 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultAttackDiscoveryPrompt } from '.';
|
||||
|
||||
describe('getDefaultAttackDiscoveryPrompt', () => {
|
||||
it('returns the default attack discovery prompt', () => {
|
||||
expect(getDefaultAttackDiscoveryPrompt()).toEqual(
|
||||
"You are a cyber security analyst tasked with analyzing security events from Elastic Security to identify and report on potential cyber attacks or progressions. Your report should focus on high-risk incidents that could severely impact the organization, rather than isolated alerts. Present your findings in a way that can be easily understood by anyone, regardless of their technical expertise, as if you were briefing the CISO. Break down your response into sections based on timing, hosts, and users involved. When correlating alerts, use kibana.alert.original_time when it's available, otherwise use @timestamp. Include appropriate context about the affected hosts and users. Describe how the attack progression might have occurred and, if feasible, attribute it to known threat groups. Prioritize high and critical alerts, but include lower-severity alerts if desired. In the description field, provide as much detail as possible, in a bulleted list explaining any attack progressions. Accuracy is of utmost importance. You MUST escape all JSON special characters (i.e. backslashes, double quotes, newlines, tabs, carriage returns, backspaces, and form feeds)."
|
||||
);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { Logger } from '@kbn/core/server';
|
||||
|
||||
import { parseCombinedOrThrow } from '.';
|
||||
import { getRawAttackDiscoveriesMock } from '../../../../../../../__mocks__/raw_attack_discoveries';
|
||||
|
||||
describe('parseCombinedOrThrow', () => {
|
||||
const mockLogger: Logger = {
|
||||
debug: jest.fn(),
|
||||
} as unknown as Logger;
|
||||
|
||||
const nodeName = 'testNodeName';
|
||||
const llmType = 'testLlmType';
|
||||
|
||||
const validCombinedResponse = getRawAttackDiscoveriesMock();
|
||||
|
||||
const invalidCombinedResponse = 'invalid json';
|
||||
|
||||
const defaultArgs = {
|
||||
combinedResponse: validCombinedResponse,
|
||||
generationAttempts: 0,
|
||||
nodeName,
|
||||
llmType,
|
||||
logger: mockLogger,
|
||||
};
|
||||
|
||||
it('returns an Attack discovery for each insight in a valid combined response', () => {
|
||||
const discoveries = parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
});
|
||||
|
||||
expect(discoveries).toHaveLength(5);
|
||||
});
|
||||
|
||||
it('adds a timestamp to all discoveries in a valid response', () => {
|
||||
const discoveries = parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
});
|
||||
|
||||
expect(discoveries.every((discovery) => discovery.timestamp != null)).toBe(true);
|
||||
});
|
||||
|
||||
it('adds trailing backticks to the combined response if necessary', () => {
|
||||
const withLeadingJson = '```json\n'.concat(validCombinedResponse);
|
||||
|
||||
const discoveries = parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
combinedResponse: withLeadingJson,
|
||||
});
|
||||
|
||||
expect(discoveries).toHaveLength(5);
|
||||
});
|
||||
|
||||
it('logs the parsing step', () => {
|
||||
const generationAttempts = 0;
|
||||
|
||||
parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
generationAttempts,
|
||||
});
|
||||
|
||||
expect((mockLogger.debug as jest.Mock).mock.calls[0][0]()).toBe(
|
||||
`${nodeName} node is parsing extractedJson (${llmType}) from attempt ${generationAttempts}`
|
||||
);
|
||||
});
|
||||
|
||||
it('logs the validation step', () => {
|
||||
const generationAttempts = 0;
|
||||
|
||||
parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
generationAttempts,
|
||||
});
|
||||
|
||||
expect((mockLogger.debug as jest.Mock).mock.calls[1][0]()).toBe(
|
||||
`${nodeName} node is validating combined response (${llmType}) from attempt ${generationAttempts}`
|
||||
);
|
||||
});
|
||||
|
||||
it('logs the successful validation step', () => {
|
||||
const generationAttempts = 0;
|
||||
|
||||
parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
generationAttempts,
|
||||
});
|
||||
|
||||
expect((mockLogger.debug as jest.Mock).mock.calls[2][0]()).toBe(
|
||||
`${nodeName} node successfully validated Attack discoveries response (${llmType}) from attempt ${generationAttempts}`
|
||||
);
|
||||
});
|
||||
|
||||
it('throws the expected error when JSON parsing fails', () => {
|
||||
expect(() =>
|
||||
parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
combinedResponse: invalidCombinedResponse,
|
||||
})
|
||||
).toThrowError('Unexpected token \'i\', "invalid json" is not valid JSON');
|
||||
});
|
||||
|
||||
it('throws the expected error when JSON validation fails', () => {
|
||||
const invalidJson = '{ "insights": "not an array" }';
|
||||
|
||||
expect(() =>
|
||||
parseCombinedOrThrow({
|
||||
...defaultArgs,
|
||||
combinedResponse: invalidJson,
|
||||
})
|
||||
).toThrowError('Expected array, received string');
|
||||
});
|
||||
});
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { discardPreviousRefinements } from '.';
|
||||
import { mockAttackDiscoveries } from '../../../../../../evaluation/__mocks__/mock_attack_discoveries';
|
||||
import { GraphState } from '../../../../types';
|
||||
|
||||
const initialState: GraphState = {
|
||||
anonymizedAlerts: [],
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: 'attackDiscoveryPrompt',
|
||||
combinedGenerations: 'generation1generation2',
|
||||
combinedRefinements: 'refinement1', // <-- existing refinements
|
||||
errors: [],
|
||||
generationAttempts: 3,
|
||||
generations: ['generation1', 'generation2'],
|
||||
hallucinationFailures: 0,
|
||||
maxGenerationAttempts: 10,
|
||||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: ['refinement1'],
|
||||
refinePrompt: 'refinePrompt',
|
||||
replacements: {},
|
||||
unrefinedResults: [...mockAttackDiscoveries],
|
||||
};
|
||||
|
||||
describe('discardPreviousRefinements', () => {
|
||||
describe('common state updates', () => {
|
||||
let result: GraphState;
|
||||
|
||||
beforeEach(() => {
|
||||
result = discardPreviousRefinements({
|
||||
generationAttempts: initialState.generationAttempts,
|
||||
hallucinationFailures: initialState.hallucinationFailures,
|
||||
isHallucinationDetected: true,
|
||||
state: initialState,
|
||||
});
|
||||
});
|
||||
|
||||
it('resets the combined refinements', () => {
|
||||
expect(result.combinedRefinements).toBe('');
|
||||
});
|
||||
|
||||
it('increments the generation attempts', () => {
|
||||
expect(result.generationAttempts).toBe(initialState.generationAttempts + 1);
|
||||
});
|
||||
|
||||
it('resets the refinements', () => {
|
||||
expect(result.refinements).toEqual([]);
|
||||
});
|
||||
|
||||
it('increments the hallucination failures when hallucinations are detected', () => {
|
||||
expect(result.hallucinationFailures).toBe(initialState.hallucinationFailures + 1);
|
||||
});
|
||||
});
|
||||
|
||||
it('increments the hallucination failures when hallucinations are detected', () => {
|
||||
const result = discardPreviousRefinements({
|
||||
generationAttempts: initialState.generationAttempts,
|
||||
hallucinationFailures: initialState.hallucinationFailures,
|
||||
isHallucinationDetected: true, // <-- hallucinations detected
|
||||
state: initialState,
|
||||
});
|
||||
|
||||
expect(result.hallucinationFailures).toBe(initialState.hallucinationFailures + 1);
|
||||
});
|
||||
|
||||
it('does NOT increment the hallucination failures when hallucinations are NOT detected', () => {
|
||||
const result = discardPreviousRefinements({
|
||||
generationAttempts: initialState.generationAttempts,
|
||||
hallucinationFailures: initialState.hallucinationFailures,
|
||||
isHallucinationDetected: false, // <-- no hallucinations detected
|
||||
state: initialState,
|
||||
});
|
||||
|
||||
expect(result.hallucinationFailures).toBe(initialState.hallucinationFailures);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getCombinedRefinePrompt } from '.';
|
||||
import { mockAttackDiscoveries } from '../../../../../../evaluation/__mocks__/mock_attack_discoveries';
|
||||
import { getContinuePrompt } from '../../../helpers/get_continue_prompt';
|
||||
|
||||
describe('getCombinedRefinePrompt', () => {
|
||||
it('returns the base query when combinedRefinements is empty', () => {
|
||||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: '',
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: [...mockAttackDiscoveries],
|
||||
});
|
||||
|
||||
expect(result).toEqual(`Initial query
|
||||
|
||||
Refine prompt
|
||||
|
||||
"""
|
||||
${JSON.stringify(mockAttackDiscoveries, null, 2)}
|
||||
"""
|
||||
|
||||
`);
|
||||
});
|
||||
|
||||
it('returns the combined prompt when combinedRefinements is not empty', () => {
|
||||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: 'Combined refinements',
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: [...mockAttackDiscoveries],
|
||||
});
|
||||
|
||||
expect(result).toEqual(`Initial query
|
||||
|
||||
Refine prompt
|
||||
|
||||
"""
|
||||
${JSON.stringify(mockAttackDiscoveries, null, 2)}
|
||||
"""
|
||||
|
||||
|
||||
|
||||
${getContinuePrompt()}
|
||||
|
||||
"""
|
||||
Combined refinements
|
||||
"""
|
||||
|
||||
`);
|
||||
});
|
||||
|
||||
it('handles null unrefinedResults', () => {
|
||||
const result = getCombinedRefinePrompt({
|
||||
attackDiscoveryPrompt: 'Initial query',
|
||||
combinedRefinements: '',
|
||||
refinePrompt: 'Refine prompt',
|
||||
unrefinedResults: null,
|
||||
});
|
||||
|
||||
expect(result).toEqual(`Initial query
|
||||
|
||||
Refine prompt
|
||||
|
||||
"""
|
||||
null
|
||||
"""
|
||||
|
||||
`);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,19 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultRefinePrompt } from '.';
|
||||
|
||||
describe('getDefaultRefinePrompt', () => {
|
||||
it('returns the default refine prompt string', () => {
|
||||
const result = getDefaultRefinePrompt();
|
||||
|
||||
expect(result)
|
||||
.toEqual(`You previously generated the following insights, but sometimes they represent the same attack.
|
||||
|
||||
Combine the insights below, when they represent the same attack; leave any insights that are not combined unchanged:`);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getUseUnrefinedResults } from '.';
|
||||
|
||||
describe('getUseUnrefinedResults', () => {
|
||||
it('returns true if both maxHallucinationFailuresReached and maxRetriesReached are true', () => {
|
||||
const result = getUseUnrefinedResults({
|
||||
maxHallucinationFailuresReached: true,
|
||||
maxRetriesReached: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true if maxHallucinationFailuresReached is true and maxRetriesReached is false', () => {
|
||||
const result = getUseUnrefinedResults({
|
||||
maxHallucinationFailuresReached: true,
|
||||
maxRetriesReached: false,
|
||||
});
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true if maxHallucinationFailuresReached is false and maxRetriesReached is true', () => {
|
||||
const result = getUseUnrefinedResults({
|
||||
maxHallucinationFailuresReached: false,
|
||||
maxRetriesReached: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false if both maxHallucinationFailuresReached and maxRetriesReached are false', () => {
|
||||
const result = getUseUnrefinedResults({
|
||||
maxHallucinationFailuresReached: false,
|
||||
maxRetriesReached: false,
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,342 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { AttackDiscovery } from '@kbn/elastic-assistant-common';
|
||||
import type { ActionsClientLlm } from '@kbn/langchain/server';
|
||||
import { loggerMock } from '@kbn/logging-mocks';
|
||||
import { FakeLLM } from '@langchain/core/utils/testing';
|
||||
|
||||
import { getRefineNode } from '.';
|
||||
import {
|
||||
mockAnonymizedAlerts,
|
||||
mockAnonymizedAlertsReplacements,
|
||||
} from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getChainWithFormatInstructions } from '../helpers/get_chain_with_format_instructions';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from './helpers/get_default_refine_prompt';
|
||||
import { GraphState } from '../../types';
|
||||
import {
|
||||
getParsedAttackDiscoveriesMock,
|
||||
getRawAttackDiscoveriesMock,
|
||||
} from '../../../../../../__mocks__/raw_attack_discoveries';
|
||||
|
||||
const attackDiscoveryTimestamp = '2024-10-11T17:55:59.702Z';
|
||||
|
||||
export const mockUnrefinedAttackDiscoveries: AttackDiscovery[] = [
|
||||
{
|
||||
title: 'unrefinedTitle1',
|
||||
alertIds: ['unrefinedAlertId1', 'unrefinedAlertId2', 'unrefinedAlertId3'],
|
||||
timestamp: '2024-10-10T22:59:52.749Z',
|
||||
detailsMarkdown: 'unrefinedDetailsMarkdown1',
|
||||
summaryMarkdown: 'unrefinedSummaryMarkdown1 - entity A',
|
||||
mitreAttackTactics: ['Input Capture'],
|
||||
entitySummaryMarkdown: 'entitySummaryMarkdown1',
|
||||
},
|
||||
{
|
||||
title: 'unrefinedTitle2',
|
||||
alertIds: ['unrefinedAlertId3', 'unrefinedAlertId4', 'unrefinedAlertId5'],
|
||||
timestamp: '2024-10-10T22:59:52.749Z',
|
||||
detailsMarkdown: 'unrefinedDetailsMarkdown2',
|
||||
summaryMarkdown: 'unrefinedSummaryMarkdown2 - also entity A',
|
||||
mitreAttackTactics: ['Credential Access'],
|
||||
entitySummaryMarkdown: 'entitySummaryMarkdown2',
|
||||
},
|
||||
];
|
||||
|
||||
jest.mock('../helpers/get_chain_with_format_instructions', () => {
|
||||
const mockInvoke = jest.fn().mockResolvedValue('');
|
||||
|
||||
return {
|
||||
getChainWithFormatInstructions: jest.fn().mockReturnValue({
|
||||
chain: {
|
||||
invoke: mockInvoke,
|
||||
},
|
||||
formatInstructions: ['mock format instructions'],
|
||||
llmType: 'openai',
|
||||
mockInvoke, // <-- added for testing
|
||||
}),
|
||||
};
|
||||
});
|
||||
|
||||
const mockLogger = loggerMock.create();
|
||||
let mockLlm: ActionsClientLlm;
|
||||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
anonymizedAlerts: [...mockAnonymizedAlerts],
|
||||
combinedGenerations: 'gen1',
|
||||
combinedRefinements: '',
|
||||
errors: [],
|
||||
generationAttempts: 1,
|
||||
generations: ['gen1'],
|
||||
hallucinationFailures: 0,
|
||||
maxGenerationAttempts: 10,
|
||||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
replacements: {
|
||||
...mockAnonymizedAlertsReplacements,
|
||||
},
|
||||
unrefinedResults: [...mockUnrefinedAttackDiscoveries],
|
||||
};
|
||||
|
||||
describe('getRefineNode', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
jest.useFakeTimers();
|
||||
jest.setSystemTime(new Date(attackDiscoveryTimestamp));
|
||||
|
||||
mockLlm = new FakeLLM({
|
||||
response: '',
|
||||
}) as unknown as ActionsClientLlm;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.useRealTimers();
|
||||
});
|
||||
|
||||
it('returns a function', () => {
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
expect(typeof refineNode).toBe('function');
|
||||
});
|
||||
|
||||
it('invokes the chain with the unrefinedResults from state and format instructions', async () => {
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlm).chain.invoke as jest.Mock;
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
await refineNode(initialGraphState);
|
||||
|
||||
expect(mockInvoke).toHaveBeenCalledWith({
|
||||
format_instructions: ['mock format instructions'],
|
||||
query: `${initialGraphState.attackDiscoveryPrompt}
|
||||
|
||||
${getDefaultRefinePrompt()}
|
||||
|
||||
\"\"\"
|
||||
${JSON.stringify(initialGraphState.unrefinedResults, null, 2)}
|
||||
\"\"\"
|
||||
|
||||
`,
|
||||
});
|
||||
});
|
||||
|
||||
it('removes the surrounding json from the response', async () => {
|
||||
const response =
|
||||
'You asked for some JSON, here it is:\n```json\n{"key": "value"}\n```\nI hope that works for you.';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({ response }) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlm,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const state = await refineNode(initialGraphState);
|
||||
|
||||
expect(state).toEqual({
|
||||
...initialGraphState,
|
||||
combinedRefinements: '{"key": "value"}',
|
||||
errors: [
|
||||
'refine node is unable to parse (fake) response from attempt 1; (this may be an incomplete response from the model): [\n {\n "code": "invalid_type",\n "expected": "array",\n "received": "undefined",\n "path": [\n "insights"\n ],\n "message": "Required"\n }\n]',
|
||||
],
|
||||
generationAttempts: 2,
|
||||
refinements: ['{"key": "value"}'],
|
||||
});
|
||||
});
|
||||
|
||||
it('handles hallucinations', async () => {
|
||||
const hallucinatedResponse =
|
||||
'tactics like **Credential Access**, **Command and Control**, and **Persistence**.",\n "entitySummaryMarkdown": "Malware detected on host **{{ host.name hostNameValue }}**';
|
||||
|
||||
const mockLlmWithHallucination = new FakeLLM({
|
||||
response: hallucinatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithHallucination).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(hallucinatedResponse);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithHallucination,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedRefinements: '{"key": "value"}',
|
||||
refinements: ['{"key": "value"}'],
|
||||
};
|
||||
|
||||
const state = await refineNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedRefinements: '', // <-- reset
|
||||
generationAttempts: 2, // <-- incremented
|
||||
refinements: [], // <-- reset
|
||||
hallucinationFailures: 1, // <-- incremented
|
||||
});
|
||||
});
|
||||
|
||||
it('discards previous refinements and starts over when the maxRepeatedGenerations limit is reached', async () => {
|
||||
const repeatedResponse = '{"key": "value"}';
|
||||
|
||||
const mockLlmWithRepeatedGenerations = new FakeLLM({
|
||||
response: repeatedResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithRepeatedGenerations).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(repeatedResponse);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithRepeatedGenerations,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedRefinements: '{"key": "value"}{"key": "value"}',
|
||||
generationAttempts: 3,
|
||||
refinements: ['{"key": "value"}', '{"key": "value"}'],
|
||||
};
|
||||
|
||||
const state = await refineNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedRefinements: '',
|
||||
generationAttempts: 4, // <-- incremented
|
||||
refinements: [],
|
||||
});
|
||||
});
|
||||
|
||||
it('combines the response with the previous refinements', async () => {
|
||||
const response = 'refine1';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedRefinements: 'refine0',
|
||||
generationAttempts: 2,
|
||||
refinements: ['refine0'],
|
||||
};
|
||||
|
||||
const state = await refineNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
combinedRefinements: 'refine0refine1',
|
||||
errors: [
|
||||
'refine node is unable to parse (fake) response from attempt 2; (this may be an incomplete response from the model): SyntaxError: Unexpected token \'r\', "refine0refine1" is not valid JSON',
|
||||
],
|
||||
generationAttempts: 3,
|
||||
refinements: ['refine0', 'refine1'],
|
||||
});
|
||||
});
|
||||
|
||||
it('returns refined results when combined responses pass validation', async () => {
|
||||
// split the response into two parts to simulate a valid response
|
||||
const splitIndex = 100; // arbitrary index
|
||||
const firstResponse = getRawAttackDiscoveriesMock().slice(0, splitIndex);
|
||||
const secondResponse = getRawAttackDiscoveriesMock().slice(splitIndex);
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response: secondResponse,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(secondResponse);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedRefinements: firstResponse,
|
||||
generationAttempts: 2,
|
||||
refinements: [firstResponse],
|
||||
};
|
||||
|
||||
const state = await refineNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
attackDiscoveries: getParsedAttackDiscoveriesMock(attackDiscoveryTimestamp),
|
||||
combinedRefinements: firstResponse.concat(secondResponse),
|
||||
generationAttempts: 3,
|
||||
refinements: [firstResponse, secondResponse],
|
||||
});
|
||||
});
|
||||
|
||||
it('uses the unrefined results when the max number of retries has already been reached', async () => {
|
||||
const response = 'this will not pass JSON parsing';
|
||||
|
||||
const mockLlmWithResponse = new FakeLLM({
|
||||
response,
|
||||
}) as unknown as ActionsClientLlm;
|
||||
const mockInvoke = getChainWithFormatInstructions(mockLlmWithResponse).chain
|
||||
.invoke as jest.Mock;
|
||||
|
||||
mockInvoke.mockResolvedValue(response);
|
||||
|
||||
const refineNode = getRefineNode({
|
||||
llm: mockLlmWithResponse,
|
||||
logger: mockLogger,
|
||||
});
|
||||
|
||||
const withPreviousGenerations = {
|
||||
...initialGraphState,
|
||||
combinedRefinements: 'refine1',
|
||||
generationAttempts: 9,
|
||||
refinements: ['refine1'],
|
||||
};
|
||||
|
||||
const state = await refineNode(withPreviousGenerations);
|
||||
|
||||
expect(state).toEqual({
|
||||
...withPreviousGenerations,
|
||||
attackDiscoveries: state.unrefinedResults, // <-- the unrefined results are returned
|
||||
combinedRefinements: 'refine1this will not pass JSON parsing',
|
||||
errors: [
|
||||
'refine node is unable to parse (fake) response from attempt 9; (this may be an incomplete response from the model): SyntaxError: Unexpected token \'r\', "refine1thi"... is not valid JSON',
|
||||
],
|
||||
generationAttempts: 10,
|
||||
refinements: ['refine1', response],
|
||||
});
|
||||
});
|
||||
});
|
|
@ -89,7 +89,7 @@ export const getRefineNode = ({
|
|||
generationsAreRepeating({
|
||||
currentGeneration: partialResponse,
|
||||
previousGenerations: refinements,
|
||||
sampleLastNGenerations: maxRepeatedGenerations,
|
||||
sampleLastNGenerations: maxRepeatedGenerations - 1,
|
||||
})
|
||||
) {
|
||||
logger?.debug(
|
||||
|
|
|
@ -0,0 +1,104 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
|
||||
|
||||
import { AnonymizedAlertsRetriever } from '.';
|
||||
import { getMockAnonymizationFieldResponse } from '../../../../../evaluation/__mocks__/mock_anonymization_fields';
|
||||
import { getAnonymizedAlerts } from '../helpers/get_anonymized_alerts';
|
||||
|
||||
const anonymizationFields = getMockAnonymizationFieldResponse();
|
||||
|
||||
const rawAlerts = [
|
||||
'@timestamp,2024-11-05T15:42:48.034Z\n_id,07d86d116ff754f4aa57c00e23a5273c2efbc9450416823ebd1d7b343b42d11a\nevent.category,malware,intrusion_detection,process\nevent.dataset,endpoint.alerts\nevent.module,endpoint\nevent.outcome,success\nfile.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nfile.name,My Go Application.app\nfile.path,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/6D63F08A-011C-4511-8556-EAEF9AFD6340/d/Setup.app/Contents/MacOS/My Go Application.app\nhost.name,d26e9abd-6cbb-4620-a802-a22b97845d5c\nhost.os.name,macOS\nhost.os.version,13.4\nkibana.alert.original_time,2023-06-19T00:28:06.888Z\nkibana.alert.risk_score,99\nkibana.alert.rule.description,Generates a detection alert each time an Elastic Endpoint Security alert is received. Enabling this rule allows you to immediately begin investigating your Endpoint alerts.\nkibana.alert.rule.name,Malware Detection Alert\nkibana.alert.severity,critical\nkibana.alert.workflow_status,open\nmessage,Malware Detection Alert\nprocess.args,xpcproxy,application.Appify by Machine Box.My Go Application.20.23\nprocess.code_signature.exists,true\nprocess.code_signature.signing_id,a.out\nprocess.code_signature.status,code failed to satisfy specified code requirement(s)\nprocess.code_signature.subject_name,\nprocess.code_signature.trusted,false\nprocess.command_line,xpcproxy application.Appify by Machine Box.My Go Application.20.23\nprocess.executable,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/6D63F08A-011C-4511-8556-EAEF9AFD6340/d/Setup.app/Contents/MacOS/My Go Application.app\nprocess.hash.md5,e62bdd3eaf2be436fca2e67b7eede603\nprocess.hash.sha1,58a3bddbc7c45193ecbefa22ad0496b60a29dff2\nprocess.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nprocess.name,My Go Application.app\nprocess.parent.args,/sbin/launchd\nprocess.parent.args_count,1\nprocess.parent.code_signature.exists,true\nprocess.parent.code_signature.status,No error.\nprocess.parent.code_signature.subject_name,Software Signing\nprocess.parent.code_signature.trusted,true\nprocess.parent.command_line,/sbin/launchd\nprocess.parent.executable,/sbin/launchd\nprocess.parent.name,launchd\nprocess.pid,1200\nuser.name,81c3db40-f3da-4c6a-b3c8-48c776148102',
|
||||
'@timestamp,2024-11-05T15:42:48.033Z\n_id,f2d2d8bd15402e8efff81d48b70ef8cb890d5502576fb92365ee2328f5fcb123\nevent.category,malware,intrusion_detection,process\nevent.dataset,endpoint.alerts\nevent.module,endpoint\nevent.outcome,success\nfile.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nfile.name,My Go Application.app\nfile.path,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/3C4D44B9-4838-4613-BACC-BD00A9CE4025/d/Setup.app/Contents/MacOS/My Go Application.app\nhost.name,d26e9abd-6cbb-4620-a802-a22b97845d5c\nhost.os.name,macOS\nhost.os.version,13.4\nkibana.alert.original_time,2023-06-19T00:27:47.362Z\nkibana.alert.risk_score,99\nkibana.alert.rule.description,Generates a detection alert each time an Elastic Endpoint Security alert is received. Enabling this rule allows you to immediately begin investigating your Endpoint alerts.\nkibana.alert.rule.name,Malware Detection Alert\nkibana.alert.severity,critical\nkibana.alert.workflow_status,open\nmessage,Malware Detection Alert\nprocess.args,xpcproxy,application.Appify by Machine Box.My Go Application.20.23\nprocess.code_signature.exists,true\nprocess.code_signature.signing_id,a.out\nprocess.code_signature.status,code failed to satisfy specified code requirement(s)\nprocess.code_signature.subject_name,\nprocess.code_signature.trusted,false\nprocess.command_line,xpcproxy application.Appify by Machine Box.My Go Application.20.23\nprocess.executable,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/3C4D44B9-4838-4613-BACC-BD00A9CE4025/d/Setup.app/Contents/MacOS/My Go Application.app\nprocess.hash.md5,e62bdd3eaf2be436fca2e67b7eede603\nprocess.hash.sha1,58a3bddbc7c45193ecbefa22ad0496b60a29dff2\nprocess.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nprocess.name,My Go Application.app\nprocess.parent.args,/sbin/launchd\nprocess.parent.args_count,1\nprocess.parent.code_signature.exists,true\nprocess.parent.code_signature.status,No error.\nprocess.parent.code_signature.subject_name,Software Signing\nprocess.parent.code_signature.trusted,true\nprocess.parent.command_line,/sbin/launchd\nprocess.parent.executable,/sbin/launchd\nprocess.parent.name,launchd\nprocess.pid,1169\nuser.name,81c3db40-f3da-4c6a-b3c8-48c776148102',
|
||||
];
|
||||
|
||||
jest.mock('../helpers/get_anonymized_alerts', () => ({
|
||||
getAnonymizedAlerts: jest.fn(),
|
||||
}));
|
||||
|
||||
describe('AnonymizedAlertsRetriever', () => {
|
||||
let esClient: ElasticsearchClient;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
|
||||
esClient = elasticsearchClientMock.createScopedClusterClient().asCurrentUser;
|
||||
|
||||
(getAnonymizedAlerts as jest.Mock).mockResolvedValue([...rawAlerts]);
|
||||
});
|
||||
|
||||
it('returns the expected pageContent and metadata', async () => {
|
||||
const retriever = new AnonymizedAlertsRetriever({
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
anonymizationFields,
|
||||
esClient,
|
||||
size: 10,
|
||||
});
|
||||
|
||||
const documents = await retriever._getRelevantDocuments('test-query');
|
||||
|
||||
expect(documents).toEqual([
|
||||
{
|
||||
pageContent:
|
||||
'@timestamp,2024-11-05T15:42:48.034Z\n_id,07d86d116ff754f4aa57c00e23a5273c2efbc9450416823ebd1d7b343b42d11a\nevent.category,malware,intrusion_detection,process\nevent.dataset,endpoint.alerts\nevent.module,endpoint\nevent.outcome,success\nfile.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nfile.name,My Go Application.app\nfile.path,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/6D63F08A-011C-4511-8556-EAEF9AFD6340/d/Setup.app/Contents/MacOS/My Go Application.app\nhost.name,d26e9abd-6cbb-4620-a802-a22b97845d5c\nhost.os.name,macOS\nhost.os.version,13.4\nkibana.alert.original_time,2023-06-19T00:28:06.888Z\nkibana.alert.risk_score,99\nkibana.alert.rule.description,Generates a detection alert each time an Elastic Endpoint Security alert is received. Enabling this rule allows you to immediately begin investigating your Endpoint alerts.\nkibana.alert.rule.name,Malware Detection Alert\nkibana.alert.severity,critical\nkibana.alert.workflow_status,open\nmessage,Malware Detection Alert\nprocess.args,xpcproxy,application.Appify by Machine Box.My Go Application.20.23\nprocess.code_signature.exists,true\nprocess.code_signature.signing_id,a.out\nprocess.code_signature.status,code failed to satisfy specified code requirement(s)\nprocess.code_signature.subject_name,\nprocess.code_signature.trusted,false\nprocess.command_line,xpcproxy application.Appify by Machine Box.My Go Application.20.23\nprocess.executable,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/6D63F08A-011C-4511-8556-EAEF9AFD6340/d/Setup.app/Contents/MacOS/My Go Application.app\nprocess.hash.md5,e62bdd3eaf2be436fca2e67b7eede603\nprocess.hash.sha1,58a3bddbc7c45193ecbefa22ad0496b60a29dff2\nprocess.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nprocess.name,My Go Application.app\nprocess.parent.args,/sbin/launchd\nprocess.parent.args_count,1\nprocess.parent.code_signature.exists,true\nprocess.parent.code_signature.status,No error.\nprocess.parent.code_signature.subject_name,Software Signing\nprocess.parent.code_signature.trusted,true\nprocess.parent.command_line,/sbin/launchd\nprocess.parent.executable,/sbin/launchd\nprocess.parent.name,launchd\nprocess.pid,1200\nuser.name,81c3db40-f3da-4c6a-b3c8-48c776148102',
|
||||
metadata: {},
|
||||
},
|
||||
{
|
||||
pageContent:
|
||||
'@timestamp,2024-11-05T15:42:48.033Z\n_id,f2d2d8bd15402e8efff81d48b70ef8cb890d5502576fb92365ee2328f5fcb123\nevent.category,malware,intrusion_detection,process\nevent.dataset,endpoint.alerts\nevent.module,endpoint\nevent.outcome,success\nfile.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nfile.name,My Go Application.app\nfile.path,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/3C4D44B9-4838-4613-BACC-BD00A9CE4025/d/Setup.app/Contents/MacOS/My Go Application.app\nhost.name,d26e9abd-6cbb-4620-a802-a22b97845d5c\nhost.os.name,macOS\nhost.os.version,13.4\nkibana.alert.original_time,2023-06-19T00:27:47.362Z\nkibana.alert.risk_score,99\nkibana.alert.rule.description,Generates a detection alert each time an Elastic Endpoint Security alert is received. Enabling this rule allows you to immediately begin investigating your Endpoint alerts.\nkibana.alert.rule.name,Malware Detection Alert\nkibana.alert.severity,critical\nkibana.alert.workflow_status,open\nmessage,Malware Detection Alert\nprocess.args,xpcproxy,application.Appify by Machine Box.My Go Application.20.23\nprocess.code_signature.exists,true\nprocess.code_signature.signing_id,a.out\nprocess.code_signature.status,code failed to satisfy specified code requirement(s)\nprocess.code_signature.subject_name,\nprocess.code_signature.trusted,false\nprocess.command_line,xpcproxy application.Appify by Machine Box.My Go Application.20.23\nprocess.executable,/private/var/folders/_b/rmcpc65j6nv11ygrs50ctcjr0000gn/T/AppTranslocation/3C4D44B9-4838-4613-BACC-BD00A9CE4025/d/Setup.app/Contents/MacOS/My Go Application.app\nprocess.hash.md5,e62bdd3eaf2be436fca2e67b7eede603\nprocess.hash.sha1,58a3bddbc7c45193ecbefa22ad0496b60a29dff2\nprocess.hash.sha256,2c63ba2b1a5131b80e567b7a1a93997a2de07ea20d0a8f5149701c67b832c097\nprocess.name,My Go Application.app\nprocess.parent.args,/sbin/launchd\nprocess.parent.args_count,1\nprocess.parent.code_signature.exists,true\nprocess.parent.code_signature.status,No error.\nprocess.parent.code_signature.subject_name,Software Signing\nprocess.parent.code_signature.trusted,true\nprocess.parent.command_line,/sbin/launchd\nprocess.parent.executable,/sbin/launchd\nprocess.parent.name,launchd\nprocess.pid,1169\nuser.name,81c3db40-f3da-4c6a-b3c8-48c776148102',
|
||||
metadata: {},
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
it('calls getAnonymizedAlerts with the expected parameters', async () => {
|
||||
const onNewReplacements = jest.fn();
|
||||
const mockReplacements = {
|
||||
replacement1: 'SRVMAC08',
|
||||
replacement2: 'SRVWIN01',
|
||||
replacement3: 'SRVWIN02',
|
||||
};
|
||||
|
||||
const retriever = new AnonymizedAlertsRetriever({
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
anonymizationFields,
|
||||
esClient,
|
||||
onNewReplacements,
|
||||
replacements: mockReplacements,
|
||||
size: 10,
|
||||
});
|
||||
|
||||
await retriever._getRelevantDocuments('test-query');
|
||||
|
||||
expect(getAnonymizedAlerts as jest.Mock).toHaveBeenCalledWith({
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
anonymizationFields,
|
||||
esClient,
|
||||
onNewReplacements,
|
||||
replacements: mockReplacements,
|
||||
size: 10,
|
||||
});
|
||||
});
|
||||
|
||||
it('handles empty anonymized alerts', async () => {
|
||||
(getAnonymizedAlerts as jest.Mock).mockResolvedValue([]);
|
||||
|
||||
const retriever = new AnonymizedAlertsRetriever({
|
||||
esClient,
|
||||
alertsIndexPattern: 'test-pattern',
|
||||
anonymizationFields,
|
||||
size: 10,
|
||||
});
|
||||
|
||||
const documents = await retriever._getRelevantDocuments('test-query');
|
||||
|
||||
expect(documents).toHaveLength(0);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,111 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { ElasticsearchClient, Logger } from '@kbn/core/server';
|
||||
import { elasticsearchClientMock } from '@kbn/core-elasticsearch-client-server-mocks';
|
||||
import { Replacements } from '@kbn/elastic-assistant-common';
|
||||
|
||||
import { getRetrieveAnonymizedAlertsNode } from '.';
|
||||
import { mockAnonymizedAlerts } from '../../../../evaluation/__mocks__/mock_anonymized_alerts';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../refine/helpers/get_default_refine_prompt';
|
||||
import type { GraphState } from '../../types';
|
||||
|
||||
const initialGraphState: GraphState = {
|
||||
attackDiscoveries: null,
|
||||
attackDiscoveryPrompt: getDefaultAttackDiscoveryPrompt(),
|
||||
anonymizedAlerts: [],
|
||||
combinedGenerations: '',
|
||||
combinedRefinements: '',
|
||||
errors: [],
|
||||
generationAttempts: 0,
|
||||
generations: [],
|
||||
hallucinationFailures: 0,
|
||||
maxGenerationAttempts: 10,
|
||||
maxHallucinationFailures: 5,
|
||||
maxRepeatedGenerations: 3,
|
||||
refinements: [],
|
||||
refinePrompt: getDefaultRefinePrompt(),
|
||||
replacements: {},
|
||||
unrefinedResults: null,
|
||||
};
|
||||
|
||||
jest.mock('./anonymized_alerts_retriever', () => ({
|
||||
AnonymizedAlertsRetriever: jest
|
||||
.fn()
|
||||
.mockImplementation(
|
||||
({
|
||||
onNewReplacements,
|
||||
replacements,
|
||||
}: {
|
||||
onNewReplacements?: (replacements: Replacements) => void;
|
||||
replacements?: Replacements;
|
||||
}) => ({
|
||||
withConfig: jest.fn().mockReturnValue({
|
||||
invoke: jest.fn(async () => {
|
||||
if (onNewReplacements != null && replacements != null) {
|
||||
onNewReplacements(replacements);
|
||||
}
|
||||
|
||||
return mockAnonymizedAlerts;
|
||||
}),
|
||||
}),
|
||||
})
|
||||
),
|
||||
}));
|
||||
|
||||
describe('getRetrieveAnonymizedAlertsNode', () => {
|
||||
const logger = {
|
||||
debug: jest.fn(),
|
||||
} as unknown as Logger;
|
||||
|
||||
let esClient: ElasticsearchClient;
|
||||
|
||||
beforeEach(() => {
|
||||
esClient = elasticsearchClientMock.createScopedClusterClient().asCurrentUser;
|
||||
});
|
||||
|
||||
it('returns a function', () => {
|
||||
const result = getRetrieveAnonymizedAlertsNode({
|
||||
esClient,
|
||||
logger,
|
||||
});
|
||||
expect(typeof result).toBe('function');
|
||||
});
|
||||
|
||||
it('updates state with anonymized alerts', async () => {
|
||||
const state: GraphState = { ...initialGraphState };
|
||||
|
||||
const retrieveAnonymizedAlerts = getRetrieveAnonymizedAlertsNode({
|
||||
esClient,
|
||||
logger,
|
||||
});
|
||||
|
||||
const result = await retrieveAnonymizedAlerts(state);
|
||||
|
||||
expect(result).toHaveProperty('anonymizedAlerts', mockAnonymizedAlerts);
|
||||
});
|
||||
|
||||
it('calls onNewReplacements with updated replacements', async () => {
|
||||
const state: GraphState = { ...initialGraphState };
|
||||
const onNewReplacements = jest.fn();
|
||||
const replacements = { key: 'value' };
|
||||
|
||||
const retrieveAnonymizedAlerts = getRetrieveAnonymizedAlertsNode({
|
||||
esClient,
|
||||
logger,
|
||||
onNewReplacements,
|
||||
replacements,
|
||||
});
|
||||
|
||||
await retrieveAnonymizedAlerts(state);
|
||||
|
||||
expect(onNewReplacements).toHaveBeenCalledWith({
|
||||
...replacements,
|
||||
});
|
||||
});
|
||||
});
|
|
@ -60,11 +60,3 @@ export const getRetrieveAnonymizedAlertsNode = ({
|
|||
|
||||
return retrieveAnonymizedAlerts;
|
||||
};
|
||||
|
||||
/**
|
||||
* Retrieve documents
|
||||
*
|
||||
* @param {GraphState} state The current state of the graph.
|
||||
* @param {RunnableConfig | undefined} config The configuration object for tracing.
|
||||
* @returns {Promise<GraphState>} The new state object.
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getDefaultGraphState } from '.';
|
||||
import {
|
||||
DEFAULT_MAX_GENERATION_ATTEMPTS,
|
||||
DEFAULT_MAX_HALLUCINATION_FAILURES,
|
||||
DEFAULT_MAX_REPEATED_GENERATIONS,
|
||||
} from '../constants';
|
||||
import { getDefaultAttackDiscoveryPrompt } from '../nodes/helpers/get_default_attack_discovery_prompt';
|
||||
import { getDefaultRefinePrompt } from '../nodes/refine/helpers/get_default_refine_prompt';
|
||||
|
||||
const defaultAttackDiscoveryPrompt = getDefaultAttackDiscoveryPrompt();
|
||||
const defaultRefinePrompt = getDefaultRefinePrompt();
|
||||
|
||||
describe('getDefaultGraphState', () => {
|
||||
it('returns the expected default attackDiscoveries', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.attackDiscoveries?.default?.()).toBeNull();
|
||||
});
|
||||
|
||||
it('returns the expected default attackDiscoveryPrompt', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.attackDiscoveryPrompt?.default?.()).toEqual(defaultAttackDiscoveryPrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default empty collection of anonymizedAlerts', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.anonymizedAlerts?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default combinedGenerations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.combinedGenerations?.default?.()).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default combinedRefinements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.combinedRefinements?.default?.()).toBe('');
|
||||
});
|
||||
|
||||
it('returns the expected default errors state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.errors?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('return the expected default generationAttempts state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.generationAttempts?.default?.()).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default generations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.generations?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default hallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.hallucinationFailures?.default?.()).toBe(0);
|
||||
});
|
||||
|
||||
it('returns the expected default refinePrompt state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.refinePrompt?.default?.()).toEqual(defaultRefinePrompt);
|
||||
});
|
||||
|
||||
it('returns the expected default maxGenerationAttempts state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.maxGenerationAttempts?.default?.()).toBe(DEFAULT_MAX_GENERATION_ATTEMPTS);
|
||||
});
|
||||
|
||||
it('returns the expected default maxHallucinationFailures state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
expect(state.maxHallucinationFailures?.default?.()).toBe(DEFAULT_MAX_HALLUCINATION_FAILURES);
|
||||
});
|
||||
|
||||
it('returns the expected default maxRepeatedGenerations state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.maxRepeatedGenerations?.default?.()).toBe(DEFAULT_MAX_REPEATED_GENERATIONS);
|
||||
});
|
||||
|
||||
it('returns the expected default refinements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.refinements?.default?.()).toHaveLength(0);
|
||||
});
|
||||
|
||||
it('returns the expected default replacements state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.replacements?.default?.()).toEqual({});
|
||||
});
|
||||
|
||||
it('returns the expected default unrefinedResults state', () => {
|
||||
const state = getDefaultGraphState();
|
||||
|
||||
expect(state.unrefinedResults?.default?.()).toBeNull();
|
||||
});
|
||||
});
|
|
@ -0,0 +1,87 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { showEmptyStates } from '.';
|
||||
import {
|
||||
showEmptyPrompt,
|
||||
showFailurePrompt,
|
||||
showNoAlertsPrompt,
|
||||
showWelcomePrompt,
|
||||
} from '../../../helpers';
|
||||
|
||||
jest.mock('../../../helpers', () => ({
|
||||
showEmptyPrompt: jest.fn().mockReturnValue(false),
|
||||
showFailurePrompt: jest.fn().mockReturnValue(false),
|
||||
showNoAlertsPrompt: jest.fn().mockReturnValue(false),
|
||||
showWelcomePrompt: jest.fn().mockReturnValue(false),
|
||||
}));
|
||||
|
||||
const defaultArgs = {
|
||||
aiConnectorsCount: 0,
|
||||
alertsContextCount: 0,
|
||||
attackDiscoveriesCount: 0,
|
||||
connectorId: undefined,
|
||||
failureReason: null,
|
||||
isLoading: false,
|
||||
};
|
||||
|
||||
describe('showEmptyStates', () => {
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns true if showWelcomePrompt returns true', () => {
|
||||
(showWelcomePrompt as jest.Mock).mockReturnValue(true);
|
||||
|
||||
const result = showEmptyStates({
|
||||
...defaultArgs,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true if showFailurePrompt returns true', () => {
|
||||
(showFailurePrompt as jest.Mock).mockReturnValue(true);
|
||||
|
||||
const result = showEmptyStates({
|
||||
...defaultArgs,
|
||||
connectorId: 'test',
|
||||
failureReason: 'error',
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true if showNoAlertsPrompt returns true', () => {
|
||||
(showNoAlertsPrompt as jest.Mock).mockReturnValue(true);
|
||||
|
||||
const result = showEmptyStates({
|
||||
...defaultArgs,
|
||||
connectorId: 'test',
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns true if showEmptyPrompt returns true', () => {
|
||||
(showEmptyPrompt as jest.Mock).mockReturnValue(true);
|
||||
|
||||
const result = showEmptyStates({
|
||||
...defaultArgs,
|
||||
});
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false if all prompts return false', () => {
|
||||
(showWelcomePrompt as jest.Mock).mockReturnValue(false);
|
||||
(showFailurePrompt as jest.Mock).mockReturnValue(false);
|
||||
(showNoAlertsPrompt as jest.Mock).mockReturnValue(false);
|
||||
(showEmptyPrompt as jest.Mock).mockReturnValue(false);
|
||||
|
||||
const result = showEmptyStates({
|
||||
...defaultArgs,
|
||||
});
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import React from 'react';
|
||||
|
||||
import { Generate } from '.';
|
||||
import * as i18n from '../empty_prompt/translations';
|
||||
|
||||
describe('Generate Component', () => {
|
||||
it('calls onGenerate when the button is clicked', () => {
|
||||
const onGenerate = jest.fn();
|
||||
|
||||
render(<Generate isLoading={false} onGenerate={onGenerate} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('generate'));
|
||||
|
||||
expect(onGenerate).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('disables the generate button when isLoading is true', () => {
|
||||
render(<Generate isLoading={true} onGenerate={jest.fn()} />);
|
||||
|
||||
expect(screen.getByTestId('generate')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('disables the generate button when isDisabled is true', () => {
|
||||
render(<Generate isLoading={false} isDisabled={true} onGenerate={jest.fn()} />);
|
||||
|
||||
expect(screen.getByTestId('generate')).toBeDisabled();
|
||||
});
|
||||
|
||||
it('shows tooltip content when the button is disabled', async () => {
|
||||
render(<Generate isLoading={false} isDisabled={true} onGenerate={jest.fn()} />);
|
||||
|
||||
fireEvent.mouseOver(screen.getByTestId('generate'));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(i18n.SELECT_A_CONNECTOR)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import React from 'react';
|
||||
|
||||
import { AlertsSettings, MAX_ALERTS } from '.';
|
||||
|
||||
const maxAlerts = '150';
|
||||
|
||||
const setMaxAlerts = jest.fn();
|
||||
|
||||
describe('AlertsSettings', () => {
|
||||
it('calls setMaxAlerts when the alerts range changes', () => {
|
||||
render(<AlertsSettings maxAlerts={maxAlerts} setMaxAlerts={setMaxAlerts} />);
|
||||
|
||||
fireEvent.click(screen.getByText(`${MAX_ALERTS}`));
|
||||
|
||||
expect(setMaxAlerts).toHaveBeenCalledWith(`${MAX_ALERTS}`);
|
||||
});
|
||||
|
||||
it('displays the correct maxAlerts value', () => {
|
||||
render(<AlertsSettings maxAlerts={maxAlerts} setMaxAlerts={setMaxAlerts} />);
|
||||
|
||||
expect(screen.getByTestId('alertsRange')).toHaveValue(maxAlerts);
|
||||
});
|
||||
|
||||
it('displays the expected text for anonymization settings', () => {
|
||||
render(<AlertsSettings maxAlerts={maxAlerts} setMaxAlerts={setMaxAlerts} />);
|
||||
|
||||
expect(screen.getByTestId('latestAndRiskiest')).toHaveTextContent(
|
||||
'Send Attack discovery information about your 150 newest and riskiest open or acknowledged alerts.'
|
||||
);
|
||||
});
|
||||
});
|
|
@ -51,7 +51,9 @@ const AlertsSettingsComponent: React.FC<Props> = ({ maxAlerts, setMaxAlerts }) =
|
|||
|
||||
<EuiFlexItem grow={true}>
|
||||
<EuiText color="subdued" size="xs">
|
||||
<span>{i18n.LATEST_AND_RISKIEST_OPEN_ALERTS(Number(maxAlerts))}</span>
|
||||
<span data-test-subj="latestAndRiskiest">
|
||||
{i18n.LATEST_AND_RISKIEST_OPEN_ALERTS(Number(maxAlerts))}
|
||||
</span>
|
||||
</EuiText>
|
||||
</EuiFlexItem>
|
||||
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import React from 'react';
|
||||
import { fireEvent, render, screen } from '@testing-library/react';
|
||||
|
||||
import { Footer } from '.';
|
||||
|
||||
describe('Footer', () => {
|
||||
const closeModal = jest.fn();
|
||||
const onReset = jest.fn();
|
||||
const onSave = jest.fn();
|
||||
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('calls onReset when the reset button is clicked', () => {
|
||||
render(<Footer closeModal={closeModal} onReset={onReset} onSave={onSave} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('reset'));
|
||||
|
||||
expect(onReset).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('calls closeModal when the cancel button is clicked', () => {
|
||||
render(<Footer closeModal={closeModal} onReset={onReset} onSave={onSave} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('cancel'));
|
||||
|
||||
expect(closeModal).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('calls onSave when the save button is clicked', () => {
|
||||
render(<Footer closeModal={closeModal} onReset={onReset} onSave={onSave} />);
|
||||
fireEvent.click(screen.getByTestId('save'));
|
||||
|
||||
expect(onSave).toHaveBeenCalled();
|
||||
});
|
||||
});
|
|
@ -23,7 +23,7 @@ const FooterComponent: React.FC<Props> = ({ closeModal, onReset, onSave }) => {
|
|||
return (
|
||||
<EuiFlexGroup alignItems="center" gutterSize="none" justifyContent="spaceBetween">
|
||||
<EuiFlexItem grow={false}>
|
||||
<EuiButtonEmpty data-test-sub="reset" flush="both" onClick={onReset} size="s">
|
||||
<EuiButtonEmpty data-test-subj="reset" flush="both" onClick={onReset} size="s">
|
||||
{i18n.RESET}
|
||||
</EuiButtonEmpty>
|
||||
</EuiFlexItem>
|
||||
|
@ -36,13 +36,13 @@ const FooterComponent: React.FC<Props> = ({ closeModal, onReset, onSave }) => {
|
|||
`}
|
||||
grow={false}
|
||||
>
|
||||
<EuiButtonEmpty data-test-sub="cancel" onClick={closeModal} size="s">
|
||||
<EuiButtonEmpty data-test-subj="cancel" onClick={closeModal} size="s">
|
||||
{i18n.CANCEL}
|
||||
</EuiButtonEmpty>
|
||||
</EuiFlexItem>
|
||||
|
||||
<EuiFlexItem grow={false}>
|
||||
<EuiButton data-test-sub="save" fill onClick={onSave} size="s">
|
||||
<EuiButton data-test-subj="save" fill onClick={onSave} size="s">
|
||||
{i18n.SAVE}
|
||||
</EuiButton>
|
||||
</EuiFlexItem>
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { DEFAULT_ATTACK_DISCOVERY_MAX_ALERTS } from '@kbn/elastic-assistant';
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react';
|
||||
import React from 'react';
|
||||
|
||||
import { SettingsModal } from '.';
|
||||
import { MAX_ALERTS } from './alerts_settings';
|
||||
|
||||
const defaultProps = {
|
||||
connectorId: undefined,
|
||||
isLoading: false,
|
||||
localStorageAttackDiscoveryMaxAlerts: undefined,
|
||||
setLocalStorageAttackDiscoveryMaxAlerts: jest.fn(),
|
||||
};
|
||||
|
||||
describe('SettingsModal', () => {
|
||||
beforeEach(() => jest.clearAllMocks());
|
||||
|
||||
it('opens the modal when the settings button is clicked', () => {
|
||||
render(<SettingsModal {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('settings'));
|
||||
|
||||
expect(screen.getByTestId('modal')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('closes the modal when the close button is clicked', () => {
|
||||
render(<SettingsModal {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('settings'));
|
||||
expect(screen.getByTestId('modal')).toBeInTheDocument();
|
||||
|
||||
fireEvent.click(screen.getByTestId('cancel'));
|
||||
expect(screen.queryByTestId('modal')).not.toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onSave when save button is clicked', () => {
|
||||
render(<SettingsModal {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('settings'));
|
||||
fireEvent.click(screen.getByText(`${MAX_ALERTS}`));
|
||||
|
||||
fireEvent.click(screen.getByTestId('save'));
|
||||
|
||||
expect(defaultProps.setLocalStorageAttackDiscoveryMaxAlerts).toHaveBeenCalledWith(
|
||||
`${MAX_ALERTS}`
|
||||
);
|
||||
});
|
||||
|
||||
it('resets max alerts to the default when the reset button is clicked', async () => {
|
||||
render(<SettingsModal {...defaultProps} />);
|
||||
|
||||
fireEvent.click(screen.getByTestId('settings'));
|
||||
|
||||
fireEvent.click(screen.getByText(`${MAX_ALERTS}`));
|
||||
await waitFor(() => expect(screen.getByTestId('alertsRange')).toHaveValue(`${MAX_ALERTS}`));
|
||||
|
||||
fireEvent.click(screen.getByTestId('reset'));
|
||||
|
||||
await waitFor(() =>
|
||||
expect(screen.getByTestId('alertsRange')).toHaveValue(
|
||||
`${DEFAULT_ATTACK_DISCOVERY_MAX_ALERTS}`
|
||||
)
|
||||
);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getIsTourEnabled } from '.';
|
||||
|
||||
describe('getIsTourEnabled', () => {
|
||||
it('returns true when all conditions are met', () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: 'test-connector-id',
|
||||
isLoading: false,
|
||||
tourDelayElapsed: true,
|
||||
showSettingsTour: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(true);
|
||||
});
|
||||
|
||||
it('returns false when isLoading is true', () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: 'test-connector-id',
|
||||
isLoading: true, // <-- don't show the tour during loading
|
||||
tourDelayElapsed: true,
|
||||
showSettingsTour: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when connectorId is undefined because it hasn't loaded from storage", () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: undefined, // <-- don't show the tour if there is no connectorId
|
||||
isLoading: false,
|
||||
tourDelayElapsed: true,
|
||||
showSettingsTour: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when tourDelayElapsed is false', () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: 'test-connector-id',
|
||||
isLoading: false,
|
||||
tourDelayElapsed: false, // <-- don't show the tour if the delay hasn't elapsed
|
||||
showSettingsTour: true,
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it('returns false when showSettingsTour is false', () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: 'test-connector-id',
|
||||
isLoading: false,
|
||||
tourDelayElapsed: true,
|
||||
showSettingsTour: false, // <-- don't show the tour if it's disabled
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
|
||||
it("returns false when showSettingsTour is undefined because it hasn't loaded from storage", () => {
|
||||
const result = getIsTourEnabled({
|
||||
connectorId: 'test-connector-id',
|
||||
isLoading: false,
|
||||
tourDelayElapsed: true,
|
||||
showSettingsTour: undefined, // <-- don't show the tour if it's undefined
|
||||
});
|
||||
|
||||
expect(result).toBe(false);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,83 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { getLoadingCalloutAlertsCount } from '.';
|
||||
|
||||
describe('getLoadingCalloutAlertsCount', () => {
|
||||
it('returns alertsContextCount when it is a positive number', () => {
|
||||
const alertsContextCount = 5; // <-- positive number
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount,
|
||||
defaultMaxAlerts: 10,
|
||||
localStorageAttackDiscoveryMaxAlerts: '15',
|
||||
});
|
||||
|
||||
expect(result).toBe(alertsContextCount);
|
||||
});
|
||||
|
||||
it('returns defaultMaxAlerts when localStorageAttackDiscoveryMaxAlerts is undefined', () => {
|
||||
const defaultMaxAlerts = 10;
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount: null,
|
||||
defaultMaxAlerts,
|
||||
localStorageAttackDiscoveryMaxAlerts: undefined, // <-- undefined
|
||||
});
|
||||
|
||||
expect(result).toBe(defaultMaxAlerts);
|
||||
});
|
||||
|
||||
it('returns defaultMaxAlerts when localStorageAttackDiscoveryMaxAlerts is NaN', () => {
|
||||
const defaultMaxAlerts = 10;
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount: 0, // <-- not a valid alertsContextCount
|
||||
defaultMaxAlerts,
|
||||
localStorageAttackDiscoveryMaxAlerts: 'NaN', // <-- NaN
|
||||
});
|
||||
|
||||
expect(result).toBe(defaultMaxAlerts);
|
||||
});
|
||||
|
||||
it('returns defaultMaxAlerts when localStorageAttackDiscoveryMaxAlerts is 0', () => {
|
||||
const defaultMaxAlerts = 10;
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount: 0, // <-- not a valid alertsContextCount
|
||||
defaultMaxAlerts,
|
||||
localStorageAttackDiscoveryMaxAlerts: '0', // <-- NaN
|
||||
});
|
||||
|
||||
expect(result).toBe(defaultMaxAlerts);
|
||||
});
|
||||
|
||||
it("returns size from localStorageAttackDiscoveryMaxAlerts when it's a positive number", () => {
|
||||
const localStorageAttackDiscoveryMaxAlerts = '15'; // <-- positive number
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount: null,
|
||||
defaultMaxAlerts: 10,
|
||||
localStorageAttackDiscoveryMaxAlerts,
|
||||
});
|
||||
|
||||
expect(result).toBe(15);
|
||||
});
|
||||
|
||||
it('returns defaultMaxAlerts when localStorageAttackDiscoveryMaxAlerts is negative', () => {
|
||||
const localStorageAttackDiscoveryMaxAlerts = '-5'; // <-- negative number
|
||||
const defaultMaxAlerts = 10;
|
||||
|
||||
const result = getLoadingCalloutAlertsCount({
|
||||
alertsContextCount: null,
|
||||
defaultMaxAlerts: 10,
|
||||
localStorageAttackDiscoveryMaxAlerts,
|
||||
});
|
||||
|
||||
expect(result).toBe(defaultMaxAlerts);
|
||||
});
|
||||
});
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { render, screen, fireEvent } from '@testing-library/react';
|
||||
import React from 'react';
|
||||
|
||||
import { TestProviders } from '../../../common/mock';
|
||||
import { mockAttackDiscovery } from '../../mock/mock_attack_discovery';
|
||||
import { Results } from '.';
|
||||
|
||||
describe('Results', () => {
|
||||
const defaultProps = {
|
||||
aiConnectorsCount: 1,
|
||||
alertsContextCount: 100,
|
||||
alertsCount: 50,
|
||||
attackDiscoveriesCount: 1,
|
||||
connectorId: 'test-connector-id',
|
||||
failureReason: null,
|
||||
isLoading: false,
|
||||
isLoadingPost: false,
|
||||
localStorageAttackDiscoveryMaxAlerts: undefined,
|
||||
onGenerate: jest.fn(),
|
||||
onToggleShowAnonymized: jest.fn(),
|
||||
selectedConnectorAttackDiscoveries: [mockAttackDiscovery],
|
||||
selectedConnectorLastUpdated: new Date(),
|
||||
selectedConnectorReplacements: {},
|
||||
showAnonymized: false,
|
||||
};
|
||||
|
||||
it('renders the EmptyStates when showEmptyStates returns true', () => {
|
||||
render(
|
||||
<TestProviders>
|
||||
<Results {...defaultProps} aiConnectorsCount={0} />
|
||||
</TestProviders>
|
||||
);
|
||||
|
||||
expect(screen.getByTestId('welcome')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onGenerate when the generate button is clicked', () => {
|
||||
render(
|
||||
<TestProviders>
|
||||
<Results {...defaultProps} alertsContextCount={0} />
|
||||
</TestProviders>
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByTestId('generate'));
|
||||
|
||||
expect(defaultProps.onGenerate).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('renders the Summary when showSummary returns true', () => {
|
||||
render(
|
||||
<TestProviders>
|
||||
<Results {...defaultProps} />
|
||||
</TestProviders>
|
||||
);
|
||||
expect(screen.getByTestId('summary')).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it('calls onToggleShowAnonymized when the show anonymized toggle is clicked', () => {
|
||||
render(
|
||||
<TestProviders>
|
||||
<Results {...defaultProps} />
|
||||
</TestProviders>
|
||||
);
|
||||
|
||||
fireEvent.click(screen.getByTestId('toggleAnonymized'));
|
||||
|
||||
expect(defaultProps.onToggleShowAnonymized).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('renders a AttackDiscoveryPanel for the attack discovery', () => {
|
||||
render(
|
||||
<TestProviders>
|
||||
<Results {...defaultProps} />
|
||||
</TestProviders>
|
||||
);
|
||||
|
||||
expect(screen.getAllByTestId('attackDiscovery')).toHaveLength(
|
||||
defaultProps.selectedConnectorAttackDiscoveries.length
|
||||
);
|
||||
});
|
||||
});
|
Loading…
Add table
Add a link
Reference in a new issue