[Security Solution][Siem migrations] Implement rate limit backoff (#211469)

## Summary

Implements an exponential backoff retry strategy when the LLM API throws
rate limit (`429`) errors.

### Backoff implementation

- The `run` method from the `RuleMigrationsTaskClient` has been moved to
the new `RuleMigrationTaskRunner` class.
- The settings for the backoff are defined in this class with:
```ts
/** Exponential backoff configuration to handle rate limit errors */
const RETRY_CONFIG = {
  initialRetryDelaySeconds: 1,
  backoffMultiplier: 2,
  maxRetries: 8,
  // max total waiting time 4m15s (1+2+4+...+128 = 2^8 - 1 = 255s)
} as const;
```
- Only one rule will be retried at a time; the rest of the concurrent
rule translations blocked by the rate limit will wait for the API to
recover before attempting the translation again.

```ts
/** Executor sleep configuration
 * A sleep time applied at the beginning of each single rule translation in the execution pool.
 * The objective of this sleep is to spread the load of concurrent translations, and prevent hitting the rate limit repeatedly.
 * The sleep time applied is a random number between [0-value]. Every time we hit the rate limit the value is increased by the multiplier, up to the limit.
 */
const EXECUTOR_SLEEP = {
  initialValueSeconds: 3,
  multiplier: 2,
  limitSeconds: 96, // 1m36s (5 increases)
} as const;
```

### Migration batching changes

```ts
/** Number of concurrent rule translations in the pool */
const TASK_CONCURRENCY = 10 as const;
/** Number of rules loaded in memory to be translated in the pool */
const TASK_BATCH_SIZE = 100 as const;
```

#### Before 

- Batches of 15 rules were retrieved and executed in a `Promise.all`,
requiring all of them to be completed before proceeding to the next
batch.
- A "batch sleep" of 10s was executed at the end of each iteration.

#### In this PR

- Batches of 100 rules are retrieved and kept in memory. The execution
is performed in a task pool with a concurrency of 10 rules. This ensures
there are always 10 rules executing at a time.
- The "batch sleep" has been removed in favour of an "execution sleep"
of rand[0-3]s at the start of each single rule migration. This
individual sleep serves two goals:
  - Spread the load when the migration is first launched.
  - Prevent hitting the rate limit consistently: The sleep duration is
    increased every time we hit a rate limit.

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Sergi Massaneda 2025-02-21 20:54:40 +01:00 committed by GitHub
parent de7d33dec2
commit 64426b2b4d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 928 additions and 336 deletions

View file

@ -5,16 +5,16 @@
* 2.0.
*/
import { mockRuleMigrationsDataClient } from '../data/__mocks__/mocks';
import { mockRuleMigrationsTaskClient } from '../task/__mocks__/mocks';
import { createRuleMigrationsDataClientMock } from '../data/__mocks__/mocks';
import { createRuleMigrationsTaskClientMock } from '../task/__mocks__/mocks';
export const createRuleMigrationDataClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsDataClient);
.mockImplementation(() => createRuleMigrationsDataClientMock());
export const createRuleMigrationTaskClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsTaskClient);
.mockImplementation(() => createRuleMigrationsTaskClientMock());
export const createRuleMigrationClient = () => ({
data: createRuleMigrationDataClient(),

View file

@ -5,24 +5,28 @@
* 2.0.
*/
import type { RuleMigrationsDataIntegrationsClient } from '../rule_migrations_data_integrations_client';
import type { RuleMigrationsDataLookupsClient } from '../rule_migrations_data_lookups_client';
import type { RuleMigrationsDataPrebuiltRulesClient } from '../rule_migrations_data_prebuilt_rules_client';
import type { RuleMigrationsDataResourcesClient } from '../rule_migrations_data_resources_client';
import type { RuleMigrationsDataRulesClient } from '../rule_migrations_data_rules_client';
// Rule migrations data rules client
export const mockRuleMigrationsDataRulesClient = {
create: jest.fn().mockResolvedValue(undefined),
get: jest.fn().mockResolvedValue([]),
get: jest.fn().mockResolvedValue({ data: [], total: 0 }),
searchBatches: jest.fn().mockReturnValue({
next: jest.fn().mockResolvedValue([]),
all: jest.fn().mockResolvedValue([]),
}),
takePending: jest.fn().mockResolvedValue([]),
saveProcessing: jest.fn().mockResolvedValue(undefined),
saveCompleted: jest.fn().mockResolvedValue(undefined),
saveError: jest.fn().mockResolvedValue(undefined),
releaseProcessing: jest.fn().mockResolvedValue(undefined),
updateStatus: jest.fn().mockResolvedValue(undefined),
getStats: jest.fn().mockResolvedValue(undefined),
getAllStats: jest.fn().mockResolvedValue([]),
} as unknown as RuleMigrationsDataRulesClient;
} as unknown as jest.Mocked<RuleMigrationsDataRulesClient>;
export const MockRuleMigrationsDataRulesClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsDataRulesClient);
@ -35,30 +39,42 @@ export const mockRuleMigrationsDataResourcesClient = {
next: jest.fn().mockResolvedValue([]),
all: jest.fn().mockResolvedValue([]),
}),
};
} as unknown as jest.Mocked<RuleMigrationsDataResourcesClient>;
export const MockRuleMigrationsDataResourcesClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsDataResourcesClient);
export const mockRuleMigrationsDataIntegrationsClient = {
populate: jest.fn().mockResolvedValue(undefined),
retrieveIntegrations: jest.fn().mockResolvedValue([]),
};
} as unknown as jest.Mocked<RuleMigrationsDataIntegrationsClient>;
export const mockRuleMigrationsDataPrebuiltRulesClient = {
populate: jest.fn().mockResolvedValue(undefined),
search: jest.fn().mockResolvedValue([]),
} as unknown as jest.Mocked<RuleMigrationsDataPrebuiltRulesClient>;
export const mockRuleMigrationsDataLookupsClient = {
create: jest.fn().mockResolvedValue(undefined),
indexData: jest.fn().mockResolvedValue(undefined),
} as unknown as jest.Mocked<RuleMigrationsDataLookupsClient>;
// Rule migrations data client
export const mockRuleMigrationsDataClient = {
export const createRuleMigrationsDataClientMock = () => ({
rules: mockRuleMigrationsDataRulesClient,
resources: mockRuleMigrationsDataResourcesClient,
integrations: mockRuleMigrationsDataIntegrationsClient,
};
prebuiltRules: mockRuleMigrationsDataPrebuiltRulesClient,
lookups: mockRuleMigrationsDataLookupsClient,
});
export const MockRuleMigrationsDataClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsDataClient);
.mockImplementation(() => createRuleMigrationsDataClientMock());
// Rule migrations data service
export const mockIndexName = 'mocked_siem_rule_migrations_index_name';
export const mockInstall = jest.fn().mockResolvedValue(undefined);
export const mockCreateClient = jest.fn().mockReturnValue(mockRuleMigrationsDataClient);
export const mockCreateClient = jest.fn(() => createRuleMigrationsDataClientMock());
export const MockRuleMigrationsDataService = jest.fn().mockImplementation(() => ({
createAdapter: jest.fn(),

View file

@ -151,45 +151,19 @@ export class RuleMigrationsDataRulesClient extends RuleMigrationsDataBaseClient
}
}
/**
* Retrieves `pending` rule migrations with the provided id and updates their status to `processing`.
* This operation is not atomic at migration level:
* - Multiple tasks can process different migrations simultaneously.
* - Multiple tasks should not process the same migration simultaneously.
*/
async takePending(migrationId: string, size: number): Promise<StoredRuleMigration[]> {
/** Updates one rule migration status to `processing` */
async saveProcessing(id: string): Promise<void> {
const index = await this.getIndexName();
const profileId = await this.getProfileUid();
const query = this.getFilterQuery(migrationId, { status: SiemMigrationStatus.PENDING });
const storedRuleMigrations = await this.esClient
.search<RuleMigration>({ index, query, sort: '_doc', size })
.then((response) =>
this.processResponseHits(response, { status: SiemMigrationStatus.PROCESSING })
)
.catch((error) => {
this.logger.error(`Error searching rule migrations: ${error.message}`);
throw error;
});
await this.esClient
.bulk({
refresh: 'wait_for',
operations: storedRuleMigrations.flatMap(({ id, status }) => [
{ update: { _id: id, _index: index } },
{
doc: { status, updated_by: profileId, updated_at: new Date().toISOString() },
},
]),
})
.catch((error) => {
this.logger.error(
`Error updating for rule migrations status to processing: ${error.message}`
);
throw error;
});
return storedRuleMigrations;
const doc = {
status: SiemMigrationStatus.PROCESSING,
updated_by: profileId,
updated_at: new Date().toISOString(),
};
await this.esClient.update({ index, id, doc, refresh: 'wait_for' }).catch((error) => {
this.logger.error(`Error updating rule migration status to processing: ${error.message}`);
throw error;
});
}
/** Updates one rule migration with the provided data and sets the status to `completed` */

View file

@ -5,7 +5,7 @@
* 2.0.
*/
export const mockRuleMigrationsTaskClient = {
export const createRuleMigrationsTaskClientMock = () => ({
start: jest.fn().mockResolvedValue({ started: true }),
stop: jest.fn().mockResolvedValue({ stopped: true }),
getStats: jest.fn().mockResolvedValue({
@ -19,15 +19,15 @@ export const mockRuleMigrationsTaskClient = {
},
}),
getAllStats: jest.fn().mockResolvedValue([]),
};
});
export const MockRuleMigrationsTaskClient = jest
.fn()
.mockImplementation(() => mockRuleMigrationsTaskClient);
.mockImplementation(() => createRuleMigrationsTaskClientMock());
// Rule migrations task service
export const mockStopAll = jest.fn();
export const mockCreateClient = jest.fn().mockReturnValue(mockRuleMigrationsTaskClient);
export const mockCreateClient = jest.fn(() => createRuleMigrationsTaskClientMock());
export const MockRuleMigrationsTaskService = jest.fn().mockImplementation(() => ({
createClient: mockCreateClient,

View file

@ -18,6 +18,11 @@ export interface RuleMigrationsRetrieverClients {
savedObjects: SavedObjectsClientContract;
}
/**
* RuleMigrationsRetriever is a class that is responsible for retrieving all the necessary data during the rule migration process.
* It is composed of multiple retrievers that are responsible for retrieving specific types of data.
* Such as rule integrations, prebuilt rules, and rule resources.
*/
export class RuleMigrationsRetriever {
public readonly resources: RuleResourceRetriever;
public readonly integrations: IntegrationRetriever;

View file

@ -6,10 +6,6 @@
*/
import type { AuthenticatedUser, Logger } from '@kbn/core/server';
import { AbortError, abortSignalToPromise } from '@kbn/kibana-utils-plugin/server';
import type { RunnableConfig } from '@langchain/core/runnables';
import { TELEMETRY_SIEM_MIGRATION_ID } from './util/constants';
import { EsqlKnowledgeBase } from './util/esql_knowledge_base';
import {
SiemMigrationStatus,
SiemMigrationTaskStatus,
@ -19,26 +15,14 @@ import type { RuleMigrationFilters } from '../../../../../common/siem_migrations
import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client';
import type { RuleMigrationDataStats } from '../data/rule_migrations_data_rules_client';
import type { SiemRuleMigrationsClientDependencies } from '../types';
import { getRuleMigrationAgent } from './agent';
import type { MigrateRuleState } from './agent/types';
import { RuleMigrationsRetriever } from './retrievers';
import { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client';
import type {
MigrationAgent,
RuleMigrationTaskCreateAgentParams,
RuleMigrationTaskRunParams,
RuleMigrationTaskStartParams,
RuleMigrationTaskStartResult,
RuleMigrationTaskStopResult,
} from './types';
import type { ChatModel } from './util/actions_client_chat';
import { ActionsClientChat } from './util/actions_client_chat';
import { generateAssistantComment } from './util/comments';
import { RuleMigrationTaskRunner } from './rule_migrations_task_runner';
const ITERATION_BATCH_SIZE = 15 as const;
const ITERATION_SLEEP_SECONDS = 10 as const;
type MigrationsRunning = Map<string, { user: string; abortController: AbortController }>;
export type MigrationsRunning = Map<string, RuleMigrationTaskRunner>;
export class RuleMigrationsTaskClient {
constructor(
@ -51,7 +35,7 @@ export class RuleMigrationsTaskClient {
/** Starts a rule migration task */
async start(params: RuleMigrationTaskStartParams): Promise<RuleMigrationTaskStartResult> {
const { migrationId, connectorId } = params;
const { migrationId, connectorId, invocationConfig } = params;
if (this.migrationsRunning.has(migrationId)) {
return { exists: true, started: false };
}
@ -70,172 +54,52 @@ export class RuleMigrationsTaskClient {
if (rules.pending === 0) {
return { exists: true, started: false };
}
const abortController = new AbortController();
const model = await this.createModel(connectorId, migrationId, abortController);
// run the migration without awaiting it to execute it in the background
this.run({ ...params, model, abortController }).catch((error) => {
this.logger.error(`Error executing migration ID:${migrationId} with error ${error}`);
});
const migrationLogger = this.logger.get(migrationId);
const abortController = new AbortController();
const migrationTaskRunner = new RuleMigrationTaskRunner(
migrationId,
this.currentUser,
abortController,
this.data,
migrationLogger,
this.dependencies
);
await migrationTaskRunner.setup(connectorId);
if (this.migrationsRunning.has(migrationId)) {
// Just to prevent a race condition in the setup
throw new Error('Task already running for this migration');
}
this.migrationsRunning.set(migrationId, migrationTaskRunner);
migrationLogger.info('Starting migration');
// run the migration in the background without awaiting and resolve the `start` promise
migrationTaskRunner
.run(invocationConfig)
.catch((error) => {
// no need to throw, the `start` promise is long gone. Just log the error
migrationLogger.error('Error executing migration', error);
})
.finally(() => {
this.migrationsRunning.delete(migrationId);
});
return { exists: true, started: true };
}
private async run(params: RuleMigrationTaskRunParams): Promise<void> {
const { migrationId, invocationConfig, abortController, model } = params;
if (this.migrationsRunning.has(migrationId)) {
// This should never happen, but just in case
throw new Error(`Task already running for migration ID:${migrationId} `);
}
this.logger.info(`Starting migration ID:${migrationId}`);
this.migrationsRunning.set(migrationId, { user: this.currentUser.username, abortController });
const abortPromise = abortSignalToPromise(abortController.signal);
const withAbortRace = async <T>(task: Promise<T>) => Promise.race([task, abortPromise.promise]);
const sleep = async (seconds: number) => {
this.logger.debug(`Sleeping ${seconds}s for migration ID:${migrationId}`);
await withAbortRace(new Promise((resolve) => setTimeout(resolve, seconds * 1000)));
};
const stats = { completed: 0, failed: 0 };
const telemetryClient = new SiemMigrationTelemetryClient(
this.dependencies.telemetry,
this.logger,
migrationId,
model.model
);
const endSiemMigration = telemetryClient.startSiemMigration();
try {
this.logger.debug(`Creating agent for migration ID:${migrationId}`);
const agent = await withAbortRace(this.createAgent({ ...params, model, telemetryClient }));
const config: RunnableConfig = {
...invocationConfig,
// signal: abortController.signal, // not working properly https://github.com/langchain-ai/langgraphjs/issues/319
};
let isDone: boolean = false;
do {
const ruleMigrations = await this.data.rules.takePending(migrationId, ITERATION_BATCH_SIZE);
this.logger.debug(
`Processing ${ruleMigrations.length} rules for migration ID:${migrationId}`
);
await Promise.all(
ruleMigrations.map(async (ruleMigration) => {
this.logger.debug(`Starting migration of rule "${ruleMigration.original_rule.title}"`);
if (ruleMigration.elastic_rule?.id) {
await this.data.rules.saveCompleted(ruleMigration);
return; // skip already installed rules
}
const endRuleTranslation = telemetryClient.startRuleTranslation();
try {
const invocationData = {
original_rule: ruleMigration.original_rule,
};
// using withAbortRace is a workaround for the issue with the langGraph signal not working properly
const migrationResult = await withAbortRace<MigrateRuleState>(
agent.invoke(invocationData, config)
);
this.logger.debug(
`Migration of rule "${ruleMigration.original_rule.title}" finished`
);
endRuleTranslation({ migrationResult });
await this.data.rules.saveCompleted({
...ruleMigration,
elastic_rule: migrationResult.elastic_rule,
translation_result: migrationResult.translation_result,
comments: migrationResult.comments,
});
stats.completed++;
} catch (error) {
stats.failed++;
if (error instanceof AbortError) {
throw error;
}
endRuleTranslation({ error });
this.logger.error(
`Error migrating rule "${ruleMigration.original_rule.title} with error: ${error.message}"`
);
await this.data.rules.saveError({
...ruleMigration,
comments: [generateAssistantComment(`Error migrating rule: ${error.message}`)],
});
}
})
);
this.logger.debug(`Batch processed successfully for migration ID:${migrationId}`);
const { rules } = await this.data.rules.getStats(migrationId);
isDone = rules.pending === 0;
if (!isDone) {
await sleep(ITERATION_SLEEP_SECONDS);
}
} while (!isDone);
this.logger.info(`Finished migration ID:${migrationId}`);
endSiemMigration({ stats });
} catch (error) {
await this.data.rules.releaseProcessing(migrationId);
if (error instanceof AbortError) {
this.logger.info(`Abort signal received, stopping migration ID:${migrationId}`);
return;
} else {
endSiemMigration({ error, stats });
this.logger.error(`Error processing migration ID:${migrationId} ${error}`);
}
} finally {
this.migrationsRunning.delete(migrationId);
abortPromise.cleanup();
}
}
private async createAgent({
connectorId,
migrationId,
model,
telemetryClient,
}: RuleMigrationTaskCreateAgentParams): Promise<MigrationAgent> {
const { inferenceClient, rulesClient, savedObjectsClient } = this.dependencies;
const esqlKnowledgeBase = new EsqlKnowledgeBase(
connectorId,
migrationId,
inferenceClient,
this.logger
);
const ruleMigrationsRetriever = new RuleMigrationsRetriever(migrationId, {
data: this.data,
rules: rulesClient,
savedObjects: savedObjectsClient,
});
await ruleMigrationsRetriever.initialize();
return getRuleMigrationAgent({
model,
esqlKnowledgeBase,
ruleMigrationsRetriever,
telemetryClient,
logger: this.logger,
});
}
/** Updates all the rules in a migration to be re-executed */
public async updateToRetry(
migrationId: string,
filter: RuleMigrationFilters
): Promise<{ updated: boolean }> {
if (this.migrationsRunning.has(migrationId)) {
// not update migrations that are currently running
return { updated: false };
}
filter.installed = false; // only retry rules that are not installed
await this.data.rules.updateStatus(migrationId, filter, SiemMigrationStatus.PENDING, {
refresh: true,
});
@ -293,20 +157,4 @@ export class RuleMigrationsTaskClient {
return { exists: true, stopped: false };
}
}
private async createModel(
connectorId: string,
migrationId: string,
abortController: AbortController
): Promise<ChatModel> {
const { actionsClient } = this.dependencies;
const actionsClientChat = new ActionsClientChat(connectorId, actionsClient, this.logger);
const model = await actionsClientChat.createModel({
telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId },
maxRetries: 10,
signal: abortController.signal,
temperature: 0.05,
});
return model;
}
}

View file

@ -0,0 +1,383 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { RuleMigrationTaskRunner } from './rule_migrations_task_runner';
import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants';
import type { AuthenticatedUser } from '@kbn/core/server';
import type { SiemRuleMigrationsClientDependencies, StoredRuleMigration } from '../types';
import { createRuleMigrationsDataClientMock } from '../data/__mocks__/mocks';
import { loggerMock } from '@kbn/logging-mocks';
const mockRetrieverInitialize = jest.fn().mockResolvedValue(undefined);
jest.mock('./retrievers', () => ({
...jest.requireActual('./retrievers'),
RuleMigrationsRetriever: jest
.fn()
.mockImplementation(() => ({ initialize: mockRetrieverInitialize })),
}));
const mockCreateModel = jest.fn(() => ({ model: 'test-model' }));
jest.mock('./util/actions_client_chat', () => ({
...jest.requireActual('./util/actions_client_chat'),
ActionsClientChat: jest.fn().mockImplementation(() => ({ createModel: mockCreateModel })),
}));
const mockInvoke = jest.fn().mockResolvedValue({});
jest.mock('./agent', () => ({
...jest.requireActual('./agent'),
getRuleMigrationAgent: () => ({ invoke: mockInvoke }),
}));
jest.mock('./rule_migrations_telemetry_client', () => ({
SiemMigrationTelemetryClient: jest.fn().mockImplementation(() => ({
startSiemMigrationTask: jest.fn(() => ({
startRuleTranslation: jest.fn(() => ({ success: jest.fn(), failure: jest.fn() })),
success: jest.fn(),
failure: jest.fn(),
})),
})),
}));
// Mock dependencies
const mockLogger = loggerMock.create();
const mockDependencies: jest.Mocked<SiemRuleMigrationsClientDependencies> = {
rulesClient: {},
savedObjectsClient: {},
inferenceClient: {},
actionsClient: {},
telemetry: {},
} as unknown as SiemRuleMigrationsClientDependencies;
const mockUser = {} as unknown as AuthenticatedUser;
const ruleId = 'test-rule-id';
// Timer setup shared by every test in this file:
// - `setTimeout` is spied on so tests can assert how many sleeps were requested and with which delays.
// - the stubbed implementation invokes the callback synchronously, so no test ever waits for real time.
// NOTE(review): presumably the task runner's sleep helpers are built on `setTimeout` — confirm against the runner.
jest.useFakeTimers();
jest.spyOn(global, 'setTimeout');
const mockTimeout = setTimeout as unknown as jest.Mock;
mockTimeout.mockImplementation((cb) => {
// never actually wait, we'll check the calls manually
cb();
});
describe('RuleMigrationTaskRunner', () => {
let taskRunner: RuleMigrationTaskRunner;
let abortController: AbortController;
let mockRuleMigrationsDataClient: ReturnType<typeof createRuleMigrationsDataClientMock>;
beforeEach(() => {
mockRetrieverInitialize.mockResolvedValue(undefined); // Reset the mock
mockInvoke.mockResolvedValue({}); // Reset the mock
mockRuleMigrationsDataClient = createRuleMigrationsDataClientMock();
jest.clearAllMocks();
abortController = new AbortController();
taskRunner = new RuleMigrationTaskRunner(
'test-migration-id',
mockUser,
abortController,
mockRuleMigrationsDataClient,
mockLogger,
mockDependencies
);
});
describe('setup', () => {
it('should create the agent and tools', async () => {
await expect(taskRunner.setup('test-connector-id')).resolves.toBeUndefined();
// @ts-expect-error (checking private properties)
expect(taskRunner.agent).toBeDefined();
// @ts-expect-error (checking private properties)
expect(taskRunner.retriever).toBeDefined();
// @ts-expect-error (checking private properties)
expect(taskRunner.telemetry).toBeDefined();
});
it('should throw if an error occurs', async () => {
const errorMessage = 'Test error';
mockCreateModel.mockImplementationOnce(() => {
throw new Error(errorMessage);
});
await expect(taskRunner.setup('test-connector-id')).rejects.toThrowError(errorMessage);
});
});
describe('run', () => {
let runPromise: Promise<void>;
beforeEach(async () => {
await taskRunner.setup('test-connector-id');
});
it('should handle the migration successfully', async () => {
mockRuleMigrationsDataClient.rules.get.mockResolvedValue({ total: 0, data: [] });
mockRuleMigrationsDataClient.rules.get.mockResolvedValueOnce({
total: 1,
data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[],
});
await taskRunner.setup('test-connector-id');
await expect(taskRunner.run({})).resolves.toBeUndefined();
expect(mockRuleMigrationsDataClient.rules.saveProcessing).toHaveBeenCalled();
expect(mockTimeout).toHaveBeenCalledTimes(1); // execution sleep
expect(mockInvoke).toHaveBeenCalledTimes(1);
expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalled();
expect(mockRuleMigrationsDataClient.rules.get).toHaveBeenCalledTimes(2); // One with data, one without
expect(mockLogger.info).toHaveBeenCalledWith('Migration completed successfully');
});
describe('when error occurs', () => {
const errorMessage = 'Test error message';
describe('during initialization', () => {
it('should handle abort error correctly', async () => {
runPromise = taskRunner.run({});
abortController.abort(); // Trigger the abort signal
await expect(runPromise).resolves.toBeUndefined(); // Ensure the function handles abort gracefully
expect(mockLogger.info).toHaveBeenCalledWith(
'Abort signal received, stopping initialization'
);
});
// A non-abort failure while initializing the retriever must be logged as an error, not thrown to the caller.
it('should handle other errors correctly', async () => {
mockRetrieverInitialize.mockRejectedValueOnce(new Error(errorMessage));
runPromise = taskRunner.run({});
await expect(runPromise).resolves.toBeUndefined(); // run() must resolve (not reject) even though initialization failed
expect(mockLogger.error).toHaveBeenCalledWith(
`Error initializing migration: Error: ${errorMessage}`
);
});
});
describe('during migration', () => {
beforeEach(() => {
mockRuleMigrationsDataClient.rules.get.mockRestore();
mockRuleMigrationsDataClient.rules.get
.mockResolvedValue({ total: 0, data: [] })
.mockResolvedValueOnce({
total: 1,
data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[],
});
});
it('should handle abort error correctly', async () => {
runPromise = taskRunner.run({});
await Promise.resolve(); // Wait for the initialization to complete
abortController.abort(); // Trigger the abort signal
await expect(runPromise).resolves.toBeUndefined(); // Ensure the function handles abort gracefully
expect(mockLogger.info).toHaveBeenCalledWith('Abort signal received, stopping migration');
expect(mockRuleMigrationsDataClient.rules.releaseProcessing).toHaveBeenCalled();
});
it('should handle other errors correctly', async () => {
mockInvoke.mockRejectedValue(new Error(errorMessage));
runPromise = taskRunner.run({});
await expect(runPromise).resolves.toBeUndefined();
expect(mockLogger.error).toHaveBeenCalledWith(
`Error translating rule \"${ruleId}\" with error: ${errorMessage}`
);
expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalled();
});
describe('during rate limit errors', () => {
const rule2Id = 'test-rule-id-2';
const error = new Error('429. You did way too many requests to this random LLM API bud');
beforeEach(async () => {
mockRuleMigrationsDataClient.rules.get.mockRestore();
mockRuleMigrationsDataClient.rules.get
.mockResolvedValue({ total: 0, data: [] })
.mockResolvedValueOnce({
total: 2,
data: [
{ id: ruleId, status: SiemMigrationStatus.PENDING },
{ id: rule2Id, status: SiemMigrationStatus.PENDING },
] as StoredRuleMigration[],
});
});
it('should retry with exponential backoff', async () => {
mockInvoke
.mockResolvedValue({}) // Successful calls from here on
.mockRejectedValueOnce(error) // First failed call for rule 1
.mockRejectedValueOnce(error) // First failed call for rule 2
.mockRejectedValueOnce(error) // Second failed call for rule 1
.mockRejectedValueOnce(error); // Third failed call for rule 1
await expect(taskRunner.run({})).resolves.toBeUndefined(); // success
/**
* Invoke calls:
* rule 1 -> failure -> start backoff retries
* rule 2 -> failure -> await for rule 1 backoff
* then:
* rule 1 retry 1 -> failure
* rule 1 retry 2 -> failure
* rule 1 retry 3 -> success
* then:
* rule 2 -> success
*/
expect(mockInvoke).toHaveBeenCalledTimes(6);
expect(mockTimeout).toHaveBeenCalledTimes(6); // 3 backoff sleeps + 3 execution sleeps
expect(mockTimeout).toHaveBeenNthCalledWith(
1,
expect.any(Function),
expect.any(Number)
);
expect(mockTimeout).toHaveBeenNthCalledWith(
2,
expect.any(Function),
expect.any(Number)
);
expect(mockTimeout).toHaveBeenNthCalledWith(3, expect.any(Function), 1000);
expect(mockTimeout).toHaveBeenNthCalledWith(4, expect.any(Function), 2000);
expect(mockTimeout).toHaveBeenNthCalledWith(5, expect.any(Function), 4000);
expect(mockTimeout).toHaveBeenNthCalledWith(
6,
expect.any(Function),
expect.any(Number)
);
expect(mockLogger.debug).toHaveBeenCalledWith(
`Awaiting backoff task for rule "${rule2Id}"`
);
expect(mockInvoke).toHaveBeenCalledTimes(6); // 3 retries + 3 executions
expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalledTimes(2); // 2 rules
});
// Every invocation is rate limited: the backoff is exhausted and both rules end up saved as errors.
it('should fail when reached maxRetries', async () => {
mockInvoke.mockRejectedValue(error);
await expect(taskRunner.run({})).resolves.toBeUndefined(); // run() itself resolves; the per-rule failures are recorded via saveError
// maxRetries = 8
expect(mockInvoke).toHaveBeenCalledTimes(10); // 8 retries + 2 executions
expect(mockTimeout).toHaveBeenCalledTimes(10); // 8 backoff sleeps + 2 execution sleeps
expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalledTimes(2); // 2 rules
});
it('should fail when reached max recovery attempts', async () => {
const rule3Id = 'test-rule-id-3';
const rule4Id = 'test-rule-id-4';
mockRuleMigrationsDataClient.rules.get.mockRestore();
mockRuleMigrationsDataClient.rules.get
.mockResolvedValue({ total: 0, data: [] })
.mockResolvedValueOnce({
total: 4,
data: [
{ id: ruleId, status: SiemMigrationStatus.PENDING },
{ id: rule2Id, status: SiemMigrationStatus.PENDING },
{ id: rule3Id, status: SiemMigrationStatus.PENDING },
{ id: rule4Id, status: SiemMigrationStatus.PENDING },
] as StoredRuleMigration[],
});
// max recovery attempts = 3
mockInvoke
.mockResolvedValue({}) // should never reach this
.mockRejectedValueOnce(error) // 1st failed call for rule 1
.mockRejectedValueOnce(error) // 1st failed call for rule 2
.mockRejectedValueOnce(error) // 1st failed call for rule 3
.mockRejectedValueOnce(error) // 1st failed call for rule 4
.mockResolvedValueOnce({}) // Successful call for the rule 1 backoff
.mockRejectedValueOnce(error) // 2nd failed call for the rule 2 recover
.mockRejectedValueOnce(error) // 2nd failed call for the rule 3 recover
.mockRejectedValueOnce(error) // 2nd failed call for the rule 4 recover
.mockResolvedValueOnce({}) // Successful call for the rule 2 backoff
.mockRejectedValueOnce(error) // 3rd failed call for the rule 3 recover
.mockRejectedValueOnce(error) // 3rd failed call for the rule 4 recover
.mockResolvedValueOnce({}) // Successful call for the rule 3 backoff
.mockRejectedValueOnce(error); // 4th failed call for the rule 4 recover (max attempts failure)
await expect(taskRunner.run({})).resolves.toBeUndefined(); // success
expect(mockRuleMigrationsDataClient.rules.saveCompleted).toHaveBeenCalledTimes(3); // rules 1, 2 and 3
expect(mockRuleMigrationsDataClient.rules.saveError).toHaveBeenCalledTimes(1); // rule 4
});
it('should increase the executor sleep time when rate limited', async () => {
  /**
   * Current EXECUTOR_SLEEP settings:
   * initialValueSeconds: 3, multiplier: 2, limitSeconds: 96, // 1m36s (5 increases)
   *
   * Expected sleep value after each rate limited run. The last run tries to go past the limit,
   * so the value must stay at 96.
   */
  const expectedSleepAfterEachRun = [6, 12, 24, 48, 96, 96];

  const pendingRuleResponse = {
    total: 1,
    data: [{ id: ruleId, status: SiemMigrationStatus.PENDING }] as StoredRuleMigration[],
  };

  // Every run fetches one batch with a single pending rule, followed by an empty batch that ends the loop
  mockRuleMigrationsDataClient.rules.get.mockRestore();
  const rulesGetMock = mockRuleMigrationsDataClient.rules.get;
  rulesGetMock.mockResolvedValue({ total: 0, data: [] });
  expectedSleepAfterEachRun.forEach(() => {
    rulesGetMock
      .mockResolvedValueOnce(pendingRuleResponse)
      .mockResolvedValueOnce({ total: 0, data: [] });
  });

  // Initial value, before any rate limit error
  // @ts-expect-error (checking private properties)
  expect(taskRunner.executorSleepMultiplier).toBe(3);

  for (const expectedSleep of expectedSleepAfterEachRun) {
    mockInvoke.mockResolvedValue({}).mockRejectedValueOnce(error); // rate limit and recovery
    await expect(taskRunner.run({})).resolves.toBeUndefined(); // success
    // @ts-expect-error (checking private properties)
    expect(taskRunner.executorSleepMultiplier).toBe(expectedSleep);
  }

  // The last run hit the limit
  expect(mockLogger.warn).toHaveBeenCalledWith('Executor sleep reached the maximum value');
});
});
});
});
});
});

View file

@ -0,0 +1,351 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import assert from 'assert';
import type { AuthenticatedUser, Logger } from '@kbn/core/server';
import { abortSignalToPromise, AbortError } from '@kbn/kibana-utils-plugin/server';
import type { RunnableConfig } from '@langchain/core/runnables';
import { SiemMigrationStatus } from '../../../../../common/siem_migrations/constants';
import { initPromisePool } from '../../../../utils/promise_pool';
import type { RuleMigrationsDataClient } from '../data/rule_migrations_data_client';
import type { MigrateRuleState } from './agent/types';
import { getRuleMigrationAgent } from './agent';
import { RuleMigrationsRetriever } from './retrievers';
import { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client';
import type { MigrationAgent } from './types';
import { generateAssistantComment } from './util/comments';
import type { SiemRuleMigrationsClientDependencies, StoredRuleMigration } from '../types';
import { ActionsClientChat } from './util/actions_client_chat';
import { EsqlKnowledgeBase } from './util/esql_knowledge_base';
/** Number of concurrent rule translations in the pool */
const TASK_CONCURRENCY = 10 as const;
/** Number of rules loaded in memory to be translated in the pool */
const TASK_BATCH_SIZE = 100 as const;
/** Exponential backoff configuration to handle rate limit errors */
const RETRY_CONFIG = {
initialRetryDelaySeconds: 1,
backoffMultiplier: 2,
maxRetries: 8,
// The delay doubles on every retry: 1s, 2s, 4s, ... 128s.
// Total max waiting time is 4m15s (1 + 2 + ... + 128 = 2^8 - 1 = 255s)
} as const;
/** Executor sleep configuration
 * A sleep time applied at the beginning of each single rule translation in the execution pool.
 * The objective of this sleep is to spread the load of concurrent translations, and prevent hitting the rate limit repeatedly.
 * The sleep time applied is a random number between [0-value]. Every time we hit the rate limit the value is increased by the multiplier, up to the limit.
 */
const EXECUTOR_SLEEP = {
initialValueSeconds: 3,
multiplier: 2,
limitSeconds: 96, // 1m36s (5 increases: 3 -> 6 -> 12 -> 24 -> 48 -> 96)
} as const;
/** This limit should never be reached, it's a safety net to prevent infinite loops.
 * It represents the max number of consecutive rate limit recovery & failure attempts of a single rule translation.
 * This can only happen when the API cannot process TASK_CONCURRENCY translations at a time,
 * even after the executor sleep is increased on every attempt.
 */
const EXECUTOR_RECOVER_MAX_ATTEMPTS = 3 as const;
/**
 * Runs the translation of the pending rules of one migration.
 *
 * Rules are fetched in batches and translated by the migration agent inside a promise pool with
 * limited concurrency. Rate limit (429) errors are handled with an exponential backoff executed by
 * a single translation at a time, while the other rate limited translations wait for it to finish
 * before trying again.
 *
 * Usage: call `setup()` first, then `run()`. The execution can be stopped with the `abortController`.
 */
export class RuleMigrationTaskRunner {
  private telemetry?: SiemMigrationTelemetryClient;
  private agent?: MigrationAgent;
  private retriever?: RuleMigrationsRetriever;
  private actionsClientChat: ActionsClientChat;
  private abort: ReturnType<typeof abortSignalToPromise>;
  /** Upper bound (in seconds) of the random executor sleep. Increased every time a backoff finishes */
  private executorSleepMultiplier: number = EXECUTOR_SLEEP.initialValueSeconds;
  /** True while a rate limit backoff is in progress */
  public isWaiting: boolean = false;

  constructor(
    public readonly migrationId: string,
    public readonly startedBy: AuthenticatedUser,
    public readonly abortController: AbortController,
    private readonly data: RuleMigrationsDataClient,
    private readonly logger: Logger,
    private readonly dependencies: SiemRuleMigrationsClientDependencies
  ) {
    this.actionsClientChat = new ActionsClientChat(this.dependencies.actionsClient, this.logger);
    this.abort = abortSignalToPromise(this.abortController.signal);
  }

  /**
   * Retrieves the connector and creates the migration agent, along with the retriever and the
   * telemetry client it uses. Must be called before `run()`.
   *
   * @param connectorId id of the connector used to create the chat model
   */
  public async setup(connectorId: string) {
    const { rulesClient, savedObjectsClient, inferenceClient } = this.dependencies;

    const model = await this.actionsClientChat.createModel({
      connectorId,
      migrationId: this.migrationId,
      abortController: this.abortController,
    });

    const esqlKnowledgeBase = new EsqlKnowledgeBase(
      connectorId,
      this.migrationId,
      inferenceClient,
      this.logger
    );

    this.retriever = new RuleMigrationsRetriever(this.migrationId, {
      data: this.data,
      rules: rulesClient,
      savedObjects: savedObjectsClient,
    });

    this.telemetry = new SiemMigrationTelemetryClient(
      this.dependencies.telemetry,
      this.logger,
      this.migrationId,
      model.model
    );

    this.agent = getRuleMigrationAgent({
      model,
      esqlKnowledgeBase,
      ruleMigrationsRetriever: this.retriever,
      telemetryClient: this.telemetry,
      logger: this.logger,
    });
  }

  /** Initializes the retriever populating ELSER indices. It may take a few minutes */
  private async initialize() {
    assert(this.retriever, 'setup() must be called before initialize()');
    await this.retriever.initialize();
  }

  /**
   * Translates all the pending rules of the migration.
   *
   * Initialization and translation errors are reported to telemetry and logged, they do not reject
   * the returned promise. When the execution fails or is aborted during the translations, the rules
   * in processing state are released. The abort helper is cleaned up on every exit path.
   *
   * @param invocationConfig configuration passed to every agent invocation
   */
  public async run(invocationConfig: RunnableConfig): Promise<void> {
    assert(this.telemetry, 'telemetry is missing please call setup() first');
    const { telemetry, migrationId } = this;
    const migrationTaskTelemetry = telemetry.startSiemMigrationTask();

    try {
      try {
        // TODO: track the duration of the initialization alone in the telemetry
        this.logger.debug('Initializing migration');
        await this.withAbort(this.initialize()); // long running operation
      } catch (error) {
        migrationTaskTelemetry.failure(error);
        if (error instanceof AbortError) {
          this.logger.info('Abort signal received, stopping initialization');
        } else {
          this.logger.error(`Error initializing migration: ${error}`);
        }
        return; // the outer `finally` cleans up the abort helper on this path as well
      }

      const migrateRuleTask = this.createMigrateRuleTask(invocationConfig);
      this.logger.debug(`Started rule translations. Concurrency is: ${TASK_CONCURRENCY}`);

      try {
        do {
          const { data: ruleMigrations } = await this.data.rules.get(migrationId, {
            filters: { status: SiemMigrationStatus.PENDING },
            size: TASK_BATCH_SIZE, // keep these rules in memory and process them in the promise pool with concurrency limit
          });
          if (ruleMigrations.length === 0) {
            break; // no pending rules left
          }

          this.logger.debug(`Start processing batch of ${ruleMigrations.length} rules`);

          const { errors } = await initPromisePool<StoredRuleMigration, void, Error>({
            concurrency: TASK_CONCURRENCY,
            abortSignal: this.abortController.signal,
            items: ruleMigrations,
            executor: async (ruleMigration) => {
              const ruleTranslationTelemetry = migrationTaskTelemetry.startRuleTranslation();
              try {
                await this.saveRuleProcessing(ruleMigration);
                const migrationResult = await migrateRuleTask(ruleMigration);
                await this.saveRuleCompleted(ruleMigration, migrationResult);
                ruleTranslationTelemetry.success(migrationResult);
              } catch (error) {
                if (error instanceof AbortError) {
                  throw error; // aborted rules are not failures, they are released by the caller
                }
                ruleTranslationTelemetry.failure(error);
                await this.saveRuleFailed(ruleMigration, error);
              }
            },
          });

          if (errors.length > 0) {
            throw errors[0].error; // Only AbortError is thrown from the pool. The task was aborted
          }

          this.logger.debug('Batch processed successfully');
        } while (true);

        migrationTaskTelemetry.success();
        this.logger.info('Migration completed successfully');
      } catch (error) {
        await this.data.rules.releaseProcessing(migrationId);
        migrationTaskTelemetry.failure(error);
        if (error instanceof AbortError) {
          this.logger.info('Abort signal received, stopping migration');
        } else {
          this.logger.error(`Error processing migration: ${error}`);
        }
      }
    } finally {
      this.abort.cleanup();
    }
  }

  /**
   * Creates the function that translates a single rule with the agent.
   * The returned function is meant to be called concurrently, all the calls share the same backoff state.
   *
   * @param invocationConfig configuration passed to every agent invocation
   * @returns the function that migrates one rule, handling the rate limit errors
   */
  private createMigrateRuleTask(invocationConfig: RunnableConfig) {
    assert(this.agent, 'agent is missing please call setup() first');
    const { agent } = this;
    const config: RunnableConfig = {
      ...invocationConfig,
      // signal: abortController.signal, // not working properly https://github.com/langchain-ai/langgraphjs/issues/319
    };

    const invoke = async (migrationRule: StoredRuleMigration): Promise<MigrateRuleState> => {
      // using withAbort in the agent invocation is not ideal but is a workaround for the issue with the langGraph signal not working properly
      return this.withAbort<MigrateRuleState>(
        agent.invoke({ original_rule: migrationRule.original_rule }, config)
      );
    };

    // Invokes the rule translation with exponential backoff, should be called only when the rate limit has been hit
    const invokeWithBackoff = async (
      migrationRule: StoredRuleMigration
    ): Promise<MigrateRuleState> => {
      this.logger.debug(`Rate limit backoff started for rule "${migrationRule.id}"`);
      let retriesLeft: number = RETRY_CONFIG.maxRetries;
      while (true) {
        try {
          await this.sleepRetry(retriesLeft);
          retriesLeft--;
          const result = await invoke(migrationRule);
          this.logger.info(
            `Rate limit backoff completed successfully for rule "${migrationRule.id}" after ${
              RETRY_CONFIG.maxRetries - retriesLeft
            } retries`
          );
          return result;
        } catch (error) {
          if (!this.isRateLimitError(error) || retriesLeft === 0) {
            // Log exactly one message: retries exhausted, or interrupted by a different error (e.g. abort)
            const logMessage =
              retriesLeft === 0
                ? `Rate limit backoff completed unsuccessfully for rule "${migrationRule.id}"`
                : `Rate limit backoff interrupted for rule "${migrationRule.id}". ${error} `;
            this.logger.debug(logMessage);
            throw error;
          }
          this.logger.debug(
            `Rate limit backoff not completed for rule "${migrationRule.id}", retries left: ${retriesLeft}`
          );
        }
      }
    };

    // Backoff in progress, shared by all the concurrent translations
    let backoffPromise: Promise<MigrateRuleState> | undefined;

    // Migrates one rule, this function will be called concurrently by the promise pool.
    // Handles rate limit errors and ensures only one task is executing the backoff retries at a time, the rest of translation will await.
    const migrateRule = async (migrationRule: StoredRuleMigration): Promise<MigrateRuleState> => {
      let recoverAttemptsLeft: number = EXECUTOR_RECOVER_MAX_ATTEMPTS;
      while (true) {
        try {
          await this.executorSleep(); // Random sleep, increased every time we hit the rate limit.
          return await invoke(migrationRule);
        } catch (error) {
          if (!this.isRateLimitError(error) || recoverAttemptsLeft === 0) {
            throw error;
          }
          if (!backoffPromise) {
            // only one translation handles the rate limit backoff retries, the rest will await it and try again when it's resolved
            backoffPromise = invokeWithBackoff(migrationRule);
            this.isWaiting = true;
            return backoffPromise.finally(() => {
              backoffPromise = undefined;
              this.increaseExecutorSleep();
              this.isWaiting = false;
            });
          }
          this.logger.debug(`Awaiting backoff task for rule "${migrationRule.id}"`);
          await backoffPromise.catch((backoffError) => {
            if (backoffError instanceof AbortError) {
              // The execution was aborted while waiting: propagate the abort, otherwise the
              // rule would be stored as failed instead of being released
              throw backoffError;
            }
            throw error; // throw the original error
          });
          recoverAttemptsLeft--;
        }
      }
    };

    return migrateRule;
  }

  /**
   * Checks whether the error is a rate limit error: "429" (Too Many Requests) as a whole word in the error message.
   * Values without a string `message` are never considered rate limit errors.
   */
  private isRateLimitError(error: Error): boolean {
    const message = error?.message;
    return typeof message === 'string' && /\b429\b/.test(message);
  }

  /** Resolves with the promise result, or rejects as soon as the abort signal is received */
  private async withAbort<T>(promise: Promise<T>): Promise<T> {
    return Promise.race([promise, this.abort.promise]);
  }

  /** Abortable sleep. The timer is cleared when the sleep ends, whether it completed or was aborted */
  private async sleep(seconds: number) {
    let timer: ReturnType<typeof setTimeout> | undefined;
    const timeout = new Promise<void>((resolve) => {
      timer = setTimeout(resolve, seconds * 1000);
    });
    try {
      await this.withAbort(timeout);
    } finally {
      clearTimeout(timer); // do not leave the timer pending when the abort wins the race
    }
  }

  // Exponential backoff implementation
  private async sleepRetry(retriesLeft: number) {
    const seconds =
      RETRY_CONFIG.initialRetryDelaySeconds *
      Math.pow(RETRY_CONFIG.backoffMultiplier, RETRY_CONFIG.maxRetries - retriesLeft);
    this.logger.debug(`Retry sleep: ${seconds}s`);
    await this.sleep(seconds);
  }

  /** Sleeps a random time between 0 and the current executor sleep value (seconds) */
  private executorSleep = async () => {
    const seconds = Math.random() * this.executorSleepMultiplier;
    this.logger.debug(`Executor sleep: ${seconds.toFixed(3)}s`);
    await this.sleep(seconds);
  };

  /** Multiplies the executor sleep value, it is left untouched (and a warning is logged) when the result would exceed the limit */
  private increaseExecutorSleep = () => {
    const increasedMultiplier = this.executorSleepMultiplier * EXECUTOR_SLEEP.multiplier;
    if (increasedMultiplier > EXECUTOR_SLEEP.limitSeconds) {
      this.logger.warn('Executor sleep reached the maximum value');
      return;
    }
    this.executorSleepMultiplier = increasedMultiplier;
  };

  /** Stores the rule as being processed */
  private async saveRuleProcessing(ruleMigration: StoredRuleMigration) {
    this.logger.debug(`Starting translation of rule "${ruleMigration.id}"`);
    return this.data.rules.saveProcessing(ruleMigration.id);
  }

  /** Stores the rule as completed, with the translation generated by the agent */
  private async saveRuleCompleted(
    ruleMigration: StoredRuleMigration,
    migrationResult: MigrateRuleState
  ) {
    this.logger.debug(`Translation of rule "${ruleMigration.id}" succeeded`);
    const ruleMigrationTranslated = {
      ...ruleMigration,
      elastic_rule: migrationResult.elastic_rule,
      translation_result: migrationResult.translation_result,
      comments: migrationResult.comments,
    };
    return this.data.rules.saveCompleted(ruleMigrationTranslated);
  }

  /** Stores the rule as failed, with an assistant comment containing the error message */
  private async saveRuleFailed(ruleMigration: StoredRuleMigration, error: Error) {
    this.logger.error(`Error translating rule "${ruleMigration.id}" with error: ${error.message}`);
    const comments = [generateAssistantComment(`Error migrating rule: ${error.message}`)];
    return this.data.rules.saveError({ ...ruleMigration, comments });
  }
}

View file

@ -7,12 +7,10 @@
import type { Logger } from '@kbn/core/server';
import type { RuleMigrationTaskCreateClientParams } from './types';
import { RuleMigrationsTaskClient } from './rule_migrations_task_client';
export type MigrationRunning = Map<string, { user: string; abortController: AbortController }>;
import { RuleMigrationsTaskClient, type MigrationsRunning } from './rule_migrations_task_client';
export class RuleMigrationsTaskService {
private migrationsRunning: MigrationRunning;
private migrationsRunning: MigrationsRunning;
constructor(private logger: Logger) {
this.migrationsRunning = new Map();

View file

@ -18,7 +18,6 @@ const translationResultWithMatchMock = {
const translationResultMock = {
translation_result: 'partial',
} as MigrateRuleState;
const stats = { completed: 2, failed: 2 };
const preFilterRulesMock: RuleMigrationPrebuiltRule[] = [
{
rule_id: 'rule1id',
@ -96,13 +95,22 @@ describe('siemMigrationTelemetry', () => {
jest.useRealTimers();
});
it('start/end migration with error', async () => {
const endSiemMigration = siemTelemetryClient.startSiemMigration();
const error = new Error('test');
endSiemMigration({ stats, error });
const error = 'test error message';
const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask();
const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation();
// 2 success and 2 failures
ruleTranslationTelemetry.success(translationResultMock);
ruleTranslationTelemetry.success(translationResultMock);
ruleTranslationTelemetry.failure(new Error('test'));
ruleTranslationTelemetry.failure(new Error('test'));
siemMigrationTaskTelemetry.failure(new Error(error));
expect(mockTelemetry.reportEvent).toHaveBeenCalledWith('siem_migrations_migration_failure', {
completed: 2,
duration: 0,
error: 'test',
error,
failed: 2,
migrationId: 'testmigration',
model: 'testModel',
@ -110,9 +118,17 @@ describe('siemMigrationTelemetry', () => {
});
});
it('start/end migration success', async () => {
const endSiemMigration = siemTelemetryClient.startSiemMigration();
const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask();
const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation();
// 2 success and 2 failures
ruleTranslationTelemetry.success(translationResultMock);
ruleTranslationTelemetry.success(translationResultMock);
ruleTranslationTelemetry.failure(new Error('test'));
ruleTranslationTelemetry.failure(new Error('test'));
siemMigrationTaskTelemetry.success();
endSiemMigration({ stats });
expect(mockTelemetry.reportEvent).toHaveBeenCalledWith('siem_migrations_migration_success', {
completed: 2,
duration: 0,
@ -123,23 +139,23 @@ describe('siemMigrationTelemetry', () => {
});
});
it('start/end rule translation with error', async () => {
const endRuleTranslation = siemTelemetryClient.startRuleTranslation();
const error = new Error('test');
const error = 'test error message';
const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask();
const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation();
ruleTranslationTelemetry.failure(new Error(error));
endRuleTranslation({ error });
expect(mockTelemetry.reportEvent).toHaveBeenCalledWith(
'siem_migrations_rule_translation_failure',
{
error: 'test',
migrationId: 'testmigration',
model: 'testModel',
}
{ error, migrationId: 'testmigration', model: 'testModel' }
);
});
it('start/end rule translation success with prebuilt', async () => {
const endRuleTranslation = siemTelemetryClient.startRuleTranslation();
const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask();
const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation();
ruleTranslationTelemetry.success(translationResultWithMatchMock);
endRuleTranslation({ migrationResult: translationResultWithMatchMock });
expect(mockTelemetry.reportEvent).toHaveBeenCalledWith(
'siem_migrations_rule_translation_success',
{
@ -152,9 +168,11 @@ describe('siemMigrationTelemetry', () => {
);
});
it('start/end rule translation success without prebuilt', async () => {
const endRuleTranslation = siemTelemetryClient.startRuleTranslation();
const siemMigrationTaskTelemetry = siemTelemetryClient.startSiemMigrationTask();
const ruleTranslationTelemetry = siemMigrationTaskTelemetry.startRuleTranslation();
ruleTranslationTelemetry.success(translationResultMock);
endRuleTranslation({ migrationResult: translationResultMock });
expect(mockTelemetry.reportEvent).toHaveBeenCalledWith(
'siem_migrations_rule_translation_success',
{

View file

@ -27,19 +27,6 @@ interface PrebuiltRuleMatchEvent {
postFilterRule?: RuleSemanticSearchResult;
}
interface RuleTranslationEvent {
error?: Error;
migrationResult?: MigrateRuleState;
}
interface SiemMigrationEvent {
error?: Error;
stats: {
failed: number;
completed: number;
};
}
export class SiemMigrationTelemetryClient {
constructor(
private readonly telemetry: AnalyticsServiceSetup,
@ -69,6 +56,7 @@ export class SiemMigrationTelemetryClient {
postFilterIntegrationCount: postFilterIntegration ? 1 : 0,
});
}
public reportPrebuiltRulesMatch({
preFilterRules,
postFilterRule,
@ -82,60 +70,58 @@ export class SiemMigrationTelemetryClient {
postFilterRuleCount: postFilterRule ? 1 : 0,
});
}
public startRuleTranslation(): (
args: Pick<RuleTranslationEvent, 'error' | 'migrationResult'>
) => void {
public startSiemMigrationTask() {
const startTime = Date.now();
const stats = { completed: 0, failed: 0 };
return ({ error, migrationResult }) => {
const duration = Date.now() - startTime;
if (error) {
this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_FAILURE, {
return {
startRuleTranslation: () => {
const ruleStartTime = Date.now();
return {
success: (migrationResult: MigrateRuleState) => {
stats.completed++;
this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_SUCCESS, {
migrationId: this.migrationId,
translationResult: migrationResult.translation_result || '',
duration: Date.now() - ruleStartTime,
model: this.modelName,
prebuiltMatch: migrationResult.elastic_rule?.prebuilt_rule_id ? true : false,
});
},
failure: (error: Error) => {
stats.failed++;
this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_FAILURE, {
migrationId: this.migrationId,
error: error.message,
model: this.modelName,
});
},
};
},
success: () => {
const duration = Date.now() - startTime;
this.reportEvent(SIEM_MIGRATIONS_MIGRATION_SUCCESS, {
migrationId: this.migrationId,
error: error.message,
model: this.modelName,
model: this.modelName || '',
completed: stats.completed,
failed: stats.failed,
total: stats.completed + stats.failed,
duration,
});
return;
}
this.reportEvent(SIEM_MIGRATIONS_RULE_TRANSLATION_SUCCESS, {
migrationId: this.migrationId,
translationResult: migrationResult?.translation_result || '',
duration,
model: this.modelName,
prebuiltMatch: migrationResult?.elastic_rule?.prebuilt_rule_id ? true : false,
});
};
}
public startSiemMigration(): (args: Pick<SiemMigrationEvent, 'error' | 'stats'>) => void {
const startTime = Date.now();
return ({ error, stats }) => {
const duration = Date.now() - startTime;
const total = stats ? stats.completed + stats.failed : 0;
if (error) {
},
failure: (error: Error) => {
const duration = Date.now() - startTime;
this.reportEvent(SIEM_MIGRATIONS_MIGRATION_FAILURE, {
migrationId: this.migrationId,
model: this.modelName || '',
completed: stats ? stats.completed : 0,
failed: stats ? stats.failed : 0,
total,
completed: stats.completed,
failed: stats.failed,
total: stats.completed + stats.failed,
duration,
error: error.message,
});
return;
}
this.reportEvent(SIEM_MIGRATIONS_MIGRATION_SUCCESS, {
migrationId: this.migrationId,
model: this.modelName || '',
completed: stats ? stats.completed : 0,
failed: stats ? stats.failed : 0,
total,
duration,
});
},
};
}
}

View file

@ -12,6 +12,7 @@ import type { SiemRuleMigrationsClientDependencies } from '../types';
import type { getRuleMigrationAgent } from './agent';
import type { SiemMigrationTelemetryClient } from './rule_migrations_telemetry_client';
import type { ChatModel } from './util/actions_client_chat';
import type { RuleMigrationsRetriever } from './retrievers';
export type MigrationAgent = ReturnType<typeof getRuleMigrationAgent>;
@ -32,7 +33,9 @@ export interface RuleMigrationTaskRunParams extends RuleMigrationTaskStartParams
abortController: AbortController;
}
export interface RuleMigrationTaskCreateAgentParams extends RuleMigrationTaskStartParams {
export interface RuleMigrationTaskCreateAgentParams {
connectorId: string;
retriever: RuleMigrationsRetriever;
telemetryClient: SiemMigrationTelemetryClient;
model: ChatModel;
}

View file

@ -17,6 +17,7 @@ import type { CustomChatModelInput as ActionsClientBedrockChatModelParams } from
import type { ActionsClientChatOpenAIParams } from '@kbn/langchain/server/language_models/chat_openai';
import type { CustomChatModelInput as ActionsClientChatVertexAIParams } from '@kbn/langchain/server/language_models/gemini_chat';
import type { CustomChatModelInput as ActionsClientSimpleChatModelParams } from '@kbn/langchain/server/language_models/simple_chat_model';
import { TELEMETRY_SIEM_MIGRATION_ID } from './constants';
export type ChatModel =
| ActionsClientSimpleChatModel
@ -42,17 +43,23 @@ const llmTypeDictionary: Record<string, string> = {
[`.inference`]: `inference`,
};
export class ActionsClientChat {
constructor(
private readonly connectorId: string,
private readonly actionsClient: ActionsClient,
private readonly logger: Logger
) {}
interface CreateModelParams {
migrationId: string;
connectorId: string;
abortController: AbortController;
}
public async createModel(params?: ChatModelParams): Promise<ChatModel> {
const connector = await this.actionsClient.get({ id: this.connectorId });
export class ActionsClientChat {
constructor(private readonly actionsClient: ActionsClient, private readonly logger: Logger) {}
public async createModel({
migrationId,
connectorId,
abortController,
}: CreateModelParams): Promise<ChatModel> {
const connector = await this.actionsClient.get({ id: connectorId });
if (!connector) {
throw new Error(`Connector not found: ${this.connectorId}`);
throw new Error(`Connector not found: ${connectorId}`);
}
const llmType = this.getLLMType(connector.actionTypeId);
@ -60,12 +67,15 @@ export class ActionsClientChat {
const model = new ChatModelClass({
actionsClient: this.actionsClient,
connectorId: this.connectorId,
logger: this.logger,
connectorId,
llmType,
model: connector.config?.defaultModel,
...params,
streaming: false, // disabling streaming by default
streaming: false,
temperature: 0.05,
maxRetries: 1, // Only retry once inside the model, we will handle backoff retries in the task runner
telemetryMetadata: { pluginId: TELEMETRY_SIEM_MIGRATION_ID, aggregateBy: migrationId },
signal: abortController.signal,
logger: this.logger,
});
return model;
}