Mirror of https://github.com/elastic/kibana.git, synced 2025-04-25 02:09:32 -04:00
# Backport

This will backport the following commits from `main` to `8.12`:

- [[Obs AI Assistant] Add guardrails (#174060)](https://github.com/elastic/kibana/pull/174060)

<!--- Backport version: 9.4.3 -->

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport).

<!--BACKPORT [{"author":{"name":"Dario Gieselaar","email":"dario.gieselaar@elastic.co"},"sourceCommit":{"committedDate":"2024-01-02T18:35:49Z","message":"[Obs AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against function looping (max 3 calls in a completion\r\nrequest) and long function responses (max 4000 tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c","branchLabelMapping":{"^v8.13.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v8.12.0","v8.12.1","v8.13.0"],"title":"[Obs AI Assistant] Add guardrails","number":174060,"url":"https://github.com/elastic/kibana/pull/174060","mergeCommit":{"message":"[Obs AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against function looping (max 3 calls in a completion\r\nrequest) and long function responses (max 4000 tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c"}},"sourceBranch":"main","suggestedTargetBranches":["8.12"],"targetPullRequestStates":[{"branch":"8.12","label":"v8.12.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.13.0","branchLabelMappingKey":"^v8.13.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/174060","number":174060,"mergeCommit":{"message":"[Obs AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against function looping (max 3 calls in a completion\r\nrequest) and long function responses (max 4000 tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c"}}]}] BACKPORT-->

Co-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>
Parent: fdd0dc4ce5
Commit: a38fc92c30
2 changed files with 281 additions and 36 deletions
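The change being backported adds two guardrails to the assistant's completion loop: at most 3 chained function calls per completion request, and function responses truncated to 4000 tokens before they are fed back to the LLM. Distilled into a minimal standalone sketch of the same logic (`FunctionResult` and the helper names here are illustrative stand-ins, not the client's actual API):

```ts
// Distilled from the diff below: cap chained function calls and truncate
// oversized function responses before they re-enter the conversation.
// `FunctionResult` is a hypothetical stand-in for the client's response type.
import { encode, decode } from 'gpt-tokenizer';
import { take } from 'lodash';

const MAX_FUNCTION_CALLS = 3;
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

interface FunctionResult {
  content?: Record<string, unknown>;
}

function guardFunctionResponse(response: FunctionResult): FunctionResult {
  // Token-count the serialized content; 4000 tokens is the ceiling.
  const encoded = encode(JSON.stringify(response.content || {}));
  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
    return response;
  }
  // Keep the first 4000 tokens and tell the LLM the rest was cut off.
  return {
    content: {
      message: 'Function response exceeded the maximum length allowed and was truncated',
      truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
    },
  };
}

function functionsForRequest(numFunctionsCalled: number, functions: unknown[]): unknown[] {
  // Once the limit is hit, the LLM no longer gets any functions to choose from.
  return numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : functions;
}
```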
The first of the two changed files updates the `ObservabilityAIAssistantClient` test suite:

```diff
@@ -7,9 +7,11 @@
 import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
 import type { ElasticsearchClient, Logger } from '@kbn/core/server';
 import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
-import { merge } from 'lodash';
+import { waitFor } from '@testing-library/react';
+import { last, merge, repeat } from 'lodash';
+import { ChatCompletionResponseMessage } from 'openai';
 import { Subject } from 'rxjs';
-import { PassThrough, type Readable } from 'stream';
+import { EventEmitter, PassThrough, type Readable } from 'stream';
 import { finished } from 'stream/promises';
 import { ObservabilityAIAssistantClient } from '.';
 import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
   };
 }
 
-describe('Observability AI Assistant service', () => {
+describe('Observability AI Assistant client', () => {
   let client: ObservabilityAIAssistantClient;
 
   const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
   } as any;
 
   const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
-    search: jest.fn().mockResolvedValue({
-      hits: {
-        hits: [],
-      },
-    }),
-    fieldCaps: jest.fn().mockResolvedValue({
-      fields: [],
-    }),
+    search: jest.fn(),
+    fieldCaps: jest.fn(),
   } as any;
 
   const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@ describe('Observability AI Assistant service', () => {
 
   const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
     executeFunction: jest.fn(),
-    getFunctions: jest.fn().mockReturnValue([]),
-    hasFunction: jest.fn().mockImplementation((name) => {
-      return name !== 'recall';
-    }),
+    getFunctions: jest.fn(),
+    hasFunction: jest.fn(),
   } as any;
 
   let llmSimulator: LlmSimulator;
 
   function createClient() {
-    jest.clearAllMocks();
+    jest.resetAllMocks();
 
+    functionClientMock.getFunctions.mockReturnValue([]);
+    functionClientMock.hasFunction.mockImplementation((name) => {
+      return name !== 'recall';
+    });
+
+    currentUserEsClientMock.search.mockResolvedValue({
+      hits: {
+        hits: [],
+      },
+    } as any);
+
+    currentUserEsClientMock.fieldCaps.mockResolvedValue({
+      fields: [],
+    } as any);
+
     return new ObservabilityAIAssistantClient({
       actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
     );
   }
 
+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
   describe('when completing a conversation without an initial conversation id', () => {
     let stream: Readable;
 
@@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
       });
     });
   });
+
+  describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      const onLlmCall = new EventEmitter();
+
+      function waitForNextLlmCall() {
+        return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
+      }
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        onLlmCall.emit('next');
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: 'Call this function again' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+      functionClientMock.executeFunction.mockImplementation(async () => ({
+        content: 'Call this function again',
+      }));
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      async function requestAlertsFunctionCall() {
+        const body = JSON.parse(
+          (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+        );
+
+        let nextLlmCallPromise: Promise<void>;
+
+        if (body.functions?.length) {
+          nextLlmCallPromise = waitForNextLlmCall();
+          await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+        } else {
+          nextLlmCallPromise = Promise.resolve();
+          await llmSimulator.next({ content: 'Looks like we are done here' });
+        }
+
+        await llmSimulator.complete();
+
+        await nextLlmCallPromise;
+      }
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await finished(stream);
+    });
+
+    it('executed the function no more than three times', () => {
+      expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
+    });
+
+    it('does not give the LLM the choice to call a function anymore', () => {
+      const firstBody = JSON.parse(
+        (actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
+      );
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      expect(firstBody.functions.length).toBe(1);
+
+      expect(body.functions).toBeUndefined();
+    });
+  });
+
+  describe('when the function response exceeds the max no of tokens for one', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      let functionResponsePromiseResolve: Function | undefined;
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: '' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+
+      functionClientMock.executeFunction.mockImplementation(() => {
+        return new Promise((resolve) => {
+          functionResponsePromiseResolve = resolve;
+        });
+      });
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+
+      await llmSimulator.complete();
+
+      await waitFor(() => functionResponsePromiseResolve !== undefined);
+
+      functionResponsePromiseResolve!({
+        content: repeat('word ', 10000),
+      });
+
+      await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);
+
+      await llmSimulator.next({ content: 'Looks like this was truncated' });
+
+      await llmSimulator.complete();
+
+      await finished(stream);
+    });
+
+    it('truncates the message', () => {
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);
+
+      expect(parsed).toEqual({
+        message: 'Function response exceeded the maximum length allowed and was truncated',
+        truncated: expect.any(String),
+      });
+
+      expect(parsed.truncated.includes('word ')).toBe(true);
+    });
+  });
 });
```
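These tests drive the model through a `createLlmSimulator` helper whose body sits outside the hunks shown above. As a rough sketch of the pattern, assuming it frames OpenAI-style SSE chunks onto the `PassThrough` stream the mocked actions client returns as `data` (the exact chunk framing is an assumption, not part of this diff):

```ts
// Hedged sketch only: the real createLlmSimulator is not part of this diff.
// It exposes the stream the mocked actions client hands back as `data`, plus
// next()/complete() to emit OpenAI-style chunks and close the stream.
import { PassThrough } from 'stream';

function createLlmSimulator() {
  const stream = new PassThrough();
  return {
    stream,
    // Emit one delta, e.g. { content: '...' } or { function_call: { name: '...' } }.
    next: async (delta: Record<string, unknown>) => {
      stream.write(`data: ${JSON.stringify({ choices: [{ delta }] })}\n\n`);
    },
    // Signal end-of-stream the way the OpenAI streaming API does.
    complete: async () => {
      stream.write('data: [DONE]\n\n');
      stream.end();
    },
  };
}
```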
The second changed file is the `ObservabilityAIAssistantClient` implementation itself:

```diff
@@ -10,12 +10,13 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
 import type { ElasticsearchClient } from '@kbn/core/server';
 import type { Logger } from '@kbn/logging';
 import type { PublicMethodsOf } from '@kbn/utility-types';
-import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
 import type {
   ChatCompletionRequestMessage,
   CreateChatCompletionRequest,
   CreateChatCompletionResponse,
 } from 'openai';
+import { decode, encode } from 'gpt-tokenizer';
+import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
 import { isObservable, lastValueFrom } from 'rxjs';
 import { PassThrough, Readable } from 'stream';
 import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
       });
     }
 
+    let numFunctionsCalled: number = 0;
+
+    const MAX_FUNCTION_CALLS = 3;
+    const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
+
     const next = async (nextMessages: Message[]): Promise<void> => {
       const lastMessage = last(nextMessages);
 
@@ -222,7 +228,10 @@ export class ObservabilityAIAssistantClient {
           connectorId,
           stream: true,
           signal,
-          functions: functionClient
+          functions:
+            numFunctionsCalled >= MAX_FUNCTION_CALLS
+              ? []
+              : functionClient
             .getFunctions()
             .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
         })
@@ -232,7 +241,15 @@ export class ObservabilityAIAssistantClient {
       }
 
       if (isAssistantMessageWithFunctionRequest) {
-        const functionResponse = await functionClient
+        const functionResponse =
+          numFunctionsCalled >= MAX_FUNCTION_CALLS
+            ? {
+                content: {
+                  error: {},
+                  message: 'Function limit exceeded, ask the user what to do next',
+                },
+              }
+            : await functionClient
           .executeFunction({
             connectorId,
             name: lastMessage.message.function_call!.name,
@@ -240,6 +257,26 @@ export class ObservabilityAIAssistantClient {
             args: lastMessage.message.function_call!.arguments,
             signal,
           })
+          .then((response) => {
+            if (isObservable(response)) {
+              return response;
+            }
+
+            const encoded = encode(JSON.stringify(response.content || {}));
+
+            if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
+              return response;
+            }
+
+            return {
+              data: response.data,
+              content: {
+                message:
+                  'Function response exceeded the maximum length allowed and was truncated',
+                truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
+              },
+            };
+          })
           .catch((error): FunctionResponse => {
             return {
               content: {
@@ -249,6 +286,8 @@ export class ObservabilityAIAssistantClient {
             };
           });
 
+        numFunctionsCalled++;
+
         if (signal.aborted) {
           return;
         }
```