Writing of new risk score fields is controlled by feature flag

This is needed because we don't currently have a mechanism to update the
mappings for the risk score indices, and the downstream transform will
crash ES if those fields exist on docs in that index (mappings or no).

Once the above is resolved, we can basically just revert this commit.

This reuses the existing feature flag,
`entityAnalyticsAssetCriticalityEnabled`, to control whether our risk
engine will calculate and write these new fields.

It does this by placing knowledge of that flag inside the
AssetCriticalityService, and modifying behavior based on that within the
`calculateRiskScores` function.

In terms of types, the new fields have all been marked as optional. I
have not modified the API schema because a) they reflect the full
possible field list, and b) I don't believe anything is
currently consuming them.
This commit is contained in:
Ryland Herrick 2023-12-14 16:44:08 -06:00
parent 788b23a9b3
commit 71f115800b
9 changed files with 55 additions and 49 deletions

View file

@ -56,15 +56,15 @@ export interface RiskScore {
'@timestamp': string;
id_field: string;
id_value: string;
criticality_level: string | undefined;
criticality_modifier: number | undefined;
criticality_level?: string | undefined;
criticality_modifier?: number | undefined;
calculated_level: string;
calculated_score: number;
calculated_score_norm: number;
category_1_score: number;
category_1_count: number;
category_5_score: number;
category_5_count: number;
category_5_score?: number;
category_5_count?: number;
notes: string[];
inputs: RiskInputs;
}

View file

@ -10,7 +10,10 @@ import type { SearchHit } from '@elastic/elasticsearch/lib/api/types';
import type { AssetCriticalityRecord } from '../../../../common/api/entity_analytics';
import type { AssetCriticalityDataClient } from './asset_criticality_data_client';
import { assetCriticalityDataClientMock } from './asset_criticality_data_client.mock';
import { assetCriticalityServiceFactory } from './asset_criticality_service';
import {
type AssetCriticalityService,
assetCriticalityServiceFactory,
} from './asset_criticality_service';
const buildMockCriticalityHit = (
overrides: Partial<AssetCriticalityRecord> = {}
@ -30,6 +33,7 @@ describe('AssetCriticalityService', () => {
describe('#getCriticalitiesByIdentifiers()', () => {
let baseIdentifier: { id_field: string; id_value: string };
let mockAssetCriticalityDataClient: AssetCriticalityDataClient;
let service: AssetCriticalityService;
beforeEach(() => {
mockAssetCriticalityDataClient = assetCriticalityDataClientMock.create();
@ -38,13 +42,14 @@ describe('AssetCriticalityService', () => {
(mockAssetCriticalityDataClient.search as jest.Mock).mockResolvedValueOnce({
hits: { hits: [] },
});
service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
experimentalFeatures: {},
});
});
describe('specifying a single identifier', () => {
it('returns an empty response if identifier is not found', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const result = await service.getCriticalitiesByIdentifiers([baseIdentifier]);
expect(result).toEqual([]);
@ -56,9 +61,6 @@ describe('AssetCriticalityService', () => {
hits: { hits },
});
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const result = await service.getCriticalitiesByIdentifiers([baseIdentifier]);
expect(result).toEqual(hits.map((hit) => hit._source));
@ -67,18 +69,12 @@ describe('AssetCriticalityService', () => {
describe('specifying multiple identifiers', () => {
it('returns an empty response if identifier is not found', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const result = await service.getCriticalitiesByIdentifiers([baseIdentifier]);
expect(result).toEqual([]);
});
it('generates a single terms clause for multiple identifier values on the same field', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const multipleIdentifiers = [
{ id_field: 'user.name', id_value: 'one' },
{ id_field: 'user.name', id_value: 'other' },
@ -109,9 +105,6 @@ describe('AssetCriticalityService', () => {
});
it('deduplicates identifiers', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const duplicateIdentifiers = [
{ id_field: 'user.name', id_value: 'same' },
{ id_field: 'user.name', id_value: 'same' },
@ -155,9 +148,6 @@ describe('AssetCriticalityService', () => {
hits,
},
});
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const result = await service.getCriticalitiesByIdentifiers([baseIdentifier]);
expect(result).toEqual(hits.map((hit) => hit._source));
});
@ -165,18 +155,12 @@ describe('AssetCriticalityService', () => {
describe('arguments', () => {
it('accepts a single identifier as an array', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const identifier = { id_field: 'host.name', id_value: 'foo' };
expect(() => service.getCriticalitiesByIdentifiers([identifier])).not.toThrow();
});
it('accepts multiple identifiers', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const identifiers = [
{ id_field: 'host.name', id_value: 'foo' },
{ id_field: 'user.name', id_value: 'bar' },
@ -185,27 +169,18 @@ describe('AssetCriticalityService', () => {
});
it('throws an error if an empty array is provided', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
await expect(() => service.getCriticalitiesByIdentifiers([])).rejects.toThrowError(
'At least one identifier must be provided'
);
});
it('throws an error if no identifier values are provided', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
await expect(() =>
service.getCriticalitiesByIdentifiers([{ id_field: 'host.name', id_value: '' }])
).rejects.toThrowError('At least one identifier must contain a valid field and value');
});
it('throws an error if no valid identifier field/value pair is provided', async () => {
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
const identifiers = [
{ id_field: '', id_value: 'foo' },
{ id_field: 'user.name', id_value: '' },
@ -221,9 +196,6 @@ describe('AssetCriticalityService', () => {
(mockAssetCriticalityDataClient.search as jest.Mock)
.mockReset()
.mockRejectedValueOnce(new Error('foo'));
const service = assetCriticalityServiceFactory({
assetCriticalityDataClient: mockAssetCriticalityDataClient,
});
await expect(() =>
service.getCriticalitiesByIdentifiers([baseIdentifier])
).rejects.toThrowError('foo');

View file

@ -6,6 +6,7 @@
*/
import { isEmpty } from 'lodash/fp';
import type { ExperimentalFeatures } from '../../../../common';
import type { AssetCriticalityRecord } from '../../../../common/api/entity_analytics';
import type { AssetCriticalityDataClient } from './asset_criticality_data_client';
@ -22,6 +23,7 @@ export interface AssetCriticalityService {
getCriticalitiesByIdentifiers: (
identifiers: CriticalityIdentifier[]
) => Promise<AssetCriticalityRecord[]>;
isEnabled: () => boolean;
}
const isCriticalityIdentifierValid = (identifier: CriticalityIdentifier): boolean =>
@ -86,11 +88,14 @@ const getCriticalitiesByIdentifiers = async ({
interface AssetCriticalityServiceFactoryOptions {
assetCriticalityDataClient: AssetCriticalityDataClient;
experimentalFeatures: ExperimentalFeatures;
}
export const assetCriticalityServiceFactory = ({
assetCriticalityDataClient,
experimentalFeatures,
}: AssetCriticalityServiceFactoryOptions): AssetCriticalityService => ({
getCriticalitiesByIdentifiers: (identifiers: CriticalityIdentifier[]) =>
getCriticalitiesByIdentifiers({ assetCriticalityDataClient, identifiers }),
isEnabled: () => experimentalFeatures.entityAnalyticsAssetCriticalityEnabled,
});

View file

@ -63,11 +63,13 @@ const formatForResponse = ({
criticality,
now,
identifierField,
includeNewFields = false,
}: {
bucket: RiskScoreBucket;
criticality?: AssetCriticalityRecord;
now: string;
identifierField: string;
includeNewFields?: boolean;
}): RiskScore => {
const criticalityModifier = getCriticalityModifier(criticality?.criticality_level);
const normalizedScoreWithCriticality = applyCriticalityToScore({
@ -79,12 +81,17 @@ const formatForResponse = ({
normalizedScoreWithCriticality - bucket.risk_details.value.normalized_score;
const categoryFiveCount = criticalityModifier ? 1 : 0;
const newFields = {
category_5_score: categoryFiveScore,
category_5_count: categoryFiveCount,
criticality_level: criticality?.criticality_level,
criticality_modifier: criticalityModifier,
};
return {
'@timestamp': now,
id_field: identifierField,
id_value: bucket.key[identifierField],
criticality_level: criticality?.criticality_level,
criticality_modifier: criticalityModifier,
calculated_level: calculatedLevel,
calculated_score: bucket.risk_details.value.score,
calculated_score_norm: normalizedScoreWithCriticality,
@ -93,8 +100,6 @@ const formatForResponse = ({
max: RISK_SCORING_SUM_MAX,
}),
category_1_count: bucket.risk_details.value.category_1_count,
category_5_score: categoryFiveScore,
category_5_count: categoryFiveCount,
notes: bucket.risk_details.value.notes,
inputs: bucket.inputs.hits.hits.map((riskInput) => ({
id: riskInput._id,
@ -106,6 +111,7 @@ const formatForResponse = ({
risk_score: riskInput.fields?.[ALERT_RISK_SCORE]?.[0] ?? undefined,
timestamp: riskInput.fields?.['@timestamp']?.[0] ?? undefined,
})),
...(includeNewFields ? newFields : {}),
};
};
@ -235,6 +241,12 @@ const processScores = async ({
return [];
}
if (!assetCriticalityService.isEnabled()) {
return buckets.map((bucket) =>
formatForResponse({ bucket, now, identifierField, includeNewFields: false })
);
}
const identifiers = buckets.map((bucket) => ({
id_field: identifierField,
id_value: bucket.key[identifierField],

View file

@ -14,13 +14,18 @@ import {
RISK_SCORE_CALCULATION_URL,
} from '../../../../../common/constants';
import { riskScoreCalculationRequestSchema } from '../../../../../common/entity_analytics/risk_engine/risk_score_calculation/request_schema';
import type { ExperimentalFeatures } from '../../../../../common';
import type { SecuritySolutionPluginRouter } from '../../../../types';
import { buildRouteValidation } from '../../../../utils/build_validation/route_validation';
import { assetCriticalityServiceFactory } from '../../asset_criticality';
import { riskScoreServiceFactory } from '../risk_score_service';
import { getRiskInputsIndex } from '../get_risk_inputs_index';
export const riskScoreCalculationRoute = (router: SecuritySolutionPluginRouter, logger: Logger) => {
export const riskScoreCalculationRoute = (
router: SecuritySolutionPluginRouter,
logger: Logger,
experimentalFeatures: ExperimentalFeatures
) => {
router.versioned
.post({
path: RISK_SCORE_CALCULATION_URL,
@ -46,6 +51,7 @@ export const riskScoreCalculationRoute = (router: SecuritySolutionPluginRouter,
const assetCriticalityDataClient = securityContext.getAssetCriticalityDataClient();
const assetCriticalityService = assetCriticalityServiceFactory({
assetCriticalityDataClient,
experimentalFeatures,
});
const riskScoreService = riskScoreServiceFactory({

View file

@ -15,13 +15,18 @@ import {
RISK_SCORE_PREVIEW_URL,
} from '../../../../../common/constants';
import { riskScorePreviewRequestSchema } from '../../../../../common/entity_analytics/risk_engine/risk_score_preview/request_schema';
import type { ExperimentalFeatures } from '../../../../../common';
import type { SecuritySolutionPluginRouter } from '../../../../types';
import { buildRouteValidation } from '../../../../utils/build_validation/route_validation';
import { assetCriticalityServiceFactory } from '../../asset_criticality';
import { riskScoreServiceFactory } from '../risk_score_service';
import { getRiskInputsIndex } from '../get_risk_inputs_index';
export const riskScorePreviewRoute = (router: SecuritySolutionPluginRouter, logger: Logger) => {
export const riskScorePreviewRoute = (
router: SecuritySolutionPluginRouter,
logger: Logger,
experimentalFeatures: ExperimentalFeatures
) => {
router.versioned
.post({
access: 'internal',
@ -47,6 +52,7 @@ export const riskScorePreviewRoute = (router: SecuritySolutionPluginRouter, logg
const assetCriticalityDataClient = securityContext.getAssetCriticalityDataClient();
const assetCriticalityService = assetCriticalityServiceFactory({
assetCriticalityDataClient,
experimentalFeatures,
});
const riskScoreService = riskScoreServiceFactory({

View file

@ -35,6 +35,7 @@ import {
} from './state';
import { INTERVAL, SCOPE, TIMEOUT, TYPE, VERSION } from './constants';
import { buildScopedInternalSavedObjectsClientUnsafe, convertRangeToISO } from './helpers';
import type { ExperimentalFeatures } from '../../../../../common';
import {
RISK_SCORE_EXECUTION_SUCCESS_EVENT,
RISK_SCORE_EXECUTION_ERROR_EVENT,
@ -57,12 +58,14 @@ const getTaskId = (namespace: string): string => `${TYPE}:${namespace}:${VERSION
type GetRiskScoreService = (namespace: string) => Promise<RiskScoreService>;
export const registerRiskScoringTask = ({
experimentalFeatures,
getStartServices,
kibanaVersion,
logger,
taskManager,
telemetry,
}: {
experimentalFeatures: ExperimentalFeatures;
getStartServices: StartServicesAccessor<StartPlugins>;
kibanaVersion: string;
logger: Logger;
@ -86,6 +89,7 @@ export const registerRiskScoringTask = ({
});
const assetCriticalityService = assetCriticalityServiceFactory({
assetCriticalityDataClient,
experimentalFeatures,
});
const riskEngineDataClient = new RiskEngineDataClient({

View file

@ -182,6 +182,7 @@ export class Plugin implements ISecuritySolutionPlugin {
if (experimentalFeatures.riskScoringPersistence) {
registerRiskScoringTask({
experimentalFeatures,
getStartServices: core.getStartServices,
kibanaVersion: pluginContext.env.packageInfo.version,
logger: this.logger,

View file

@ -159,8 +159,8 @@ export const initRoutes = (
}
if (config.experimentalFeatures.riskScoringRoutesEnabled) {
riskScorePreviewRoute(router, logger);
riskScoreCalculationRoute(router, logger);
riskScorePreviewRoute(router, logger, config.experimentalFeatures);
riskScoreCalculationRoute(router, logger, config.experimentalFeatures);
riskEngineStatusRoute(router);
riskEngineInitRoute(router, getStartServices);
riskEngineEnableRoute(router, getStartServices);