Mirror of https://github.com/elastic/kibana.git
# Backport

This will backport the following commits from `main` to `8.12`:

- [[Obs AI Assistant] Add guardrails (#174060)](https://github.com/elastic/kibana/pull/174060)

The source PR adds guardrails against function looping (max 3 calls in a completion request) and against long function responses (max 4000 tokens).

<!--- Backport version: 9.4.3 -->

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport).

Co-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>
Commit a38fc92c30 (parent fdd0dc4ce5): 2 changed files with 281 additions and 36 deletions.
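For orientation before reading the diff, here is a minimal TypeScript sketch (not the actual Kibana code) of the first guardrail: capping the number of function calls per completion request. `FunctionDefinition`, `functionsForNextRequest`, and `respondToFunctionCall` are illustrative names; only `MAX_FUNCTION_CALLS` and the "Function limit exceeded" response text come from the change itself.

```ts
// Sketch only: once MAX_FUNCTION_CALLS function calls have happened in a single
// completion request, the LLM is no longer offered any functions, and the pending
// function call is answered with a "limit exceeded" message instead of being executed.
const MAX_FUNCTION_CALLS = 3;

interface FunctionDefinition {
  name: string;
  description: string;
  parameters: Record<string, unknown>;
}

// Decide which functions to advertise to the LLM on the next request.
function functionsForNextRequest(
  numFunctionsCalled: number,
  allFunctions: FunctionDefinition[]
): FunctionDefinition[] {
  return numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : allFunctions;
}

// Answer a requested function call, short-circuiting once the cap is reached.
async function respondToFunctionCall(
  numFunctionsCalled: number,
  executeFunction: () => Promise<{ content: unknown }>
): Promise<{ content: unknown }> {
  if (numFunctionsCalled >= MAX_FUNCTION_CALLS) {
    return {
      content: { error: {}, message: 'Function limit exceeded, ask the user what to do next' },
    };
  }
  return executeFunction();
}
```

Once the cap is hit the LLM is no longer given the option to call a function, which is what the new test "does not give the LLM the choice to call a function anymore" asserts below.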
@@ -7,9 +7,11 @@
 import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
 import type { ElasticsearchClient, Logger } from '@kbn/core/server';
 import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
-import { merge } from 'lodash';
+import { waitFor } from '@testing-library/react';
+import { last, merge, repeat } from 'lodash';
+import { ChatCompletionResponseMessage } from 'openai';
 import { Subject } from 'rxjs';
-import { PassThrough, type Readable } from 'stream';
+import { EventEmitter, PassThrough, type Readable } from 'stream';
 import { finished } from 'stream/promises';
 import { ObservabilityAIAssistantClient } from '.';
 import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
   };
 }
 
-describe('Observability AI Assistant service', () => {
+describe('Observability AI Assistant client', () => {
   let client: ObservabilityAIAssistantClient;
 
   const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
   } as any;
 
   const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
-    search: jest.fn().mockResolvedValue({
-      hits: {
-        hits: [],
-      },
-    }),
-    fieldCaps: jest.fn().mockResolvedValue({
-      fields: [],
-    }),
+    search: jest.fn(),
+    fieldCaps: jest.fn(),
   } as any;
 
   const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@ describe('Observability AI Assistant service', () => {
 
   const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
     executeFunction: jest.fn(),
-    getFunctions: jest.fn().mockReturnValue([]),
-    hasFunction: jest.fn().mockImplementation((name) => {
-      return name !== 'recall';
-    }),
+    getFunctions: jest.fn(),
+    hasFunction: jest.fn(),
   } as any;
 
   let llmSimulator: LlmSimulator;
 
   function createClient() {
-    jest.clearAllMocks();
+    jest.resetAllMocks();
+
+    functionClientMock.getFunctions.mockReturnValue([]);
+    functionClientMock.hasFunction.mockImplementation((name) => {
+      return name !== 'recall';
+    });
+
+    currentUserEsClientMock.search.mockResolvedValue({
+      hits: {
+        hits: [],
+      },
+    } as any);
+
+    currentUserEsClientMock.fieldCaps.mockResolvedValue({
+      fields: [],
+    } as any);
 
     return new ObservabilityAIAssistantClient({
       actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
     );
   }
 
+  beforeEach(() => {
+    jest.clearAllMocks();
+  });
+
   describe('when completing a conversation without an initial conversation id', () => {
     let stream: Readable;
 
@@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
       });
     });
   });
+
+  describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      const onLlmCall = new EventEmitter();
+
+      function waitForNextLlmCall() {
+        return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
+      }
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        onLlmCall.emit('next');
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: 'Call this function again' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+      functionClientMock.executeFunction.mockImplementation(async () => ({
+        content: 'Call this function again',
+      }));
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      async function requestAlertsFunctionCall() {
+        const body = JSON.parse(
+          (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+        );
+
+        let nextLlmCallPromise: Promise<void>;
+
+        if (body.functions?.length) {
+          nextLlmCallPromise = waitForNextLlmCall();
+          await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+        } else {
+          nextLlmCallPromise = Promise.resolve();
+          await llmSimulator.next({ content: 'Looks like we are done here' });
+        }
+
+        await llmSimulator.complete();
+
+        await nextLlmCallPromise;
+      }
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await requestAlertsFunctionCall();
+
+      await finished(stream);
+    });
+
+    it('executed the function no more than three times', () => {
+      expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
+    });
+
+    it('does not give the LLM the choice to call a function anymore', () => {
+      const firstBody = JSON.parse(
+        (actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
+      );
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      expect(firstBody.functions.length).toBe(1);
+
+      expect(body.functions).toBeUndefined();
+    });
+  });
+
+  describe('when the function response exceeds the max no of tokens for one', () => {
+    let stream: Readable;
+
+    let dataHandler: jest.Mock;
+
+    beforeEach(async () => {
+      client = createClient();
+
+      let functionResponsePromiseResolve: Function | undefined;
+
+      actionsClientMock.execute.mockImplementation(async () => {
+        llmSimulator = createLlmSimulator();
+        return {
+          actionId: '',
+          status: 'ok',
+          data: llmSimulator.stream,
+        };
+      });
+
+      functionClientMock.getFunctions.mockImplementation(() => [
+        {
+          definition: {
+            name: 'get_top_alerts',
+            contexts: ['core'],
+            description: '',
+            parameters: {},
+          },
+          respond: async () => {
+            return { content: '' };
+          },
+        },
+      ]);
+
+      functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
+
+      functionClientMock.executeFunction.mockImplementation(() => {
+        return new Promise((resolve) => {
+          functionResponsePromiseResolve = resolve;
+        });
+      });
+
+      stream = await client.complete({
+        connectorId: 'foo',
+        messages: [system('This is a system message'), user('How many alerts do I have?')],
+        functionClient: functionClientMock,
+        signal: new AbortController().signal,
+        title: 'My predefined title',
+        persist: true,
+      });
+
+      dataHandler = jest.fn();
+
+      stream.on('data', dataHandler);
+
+      await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
+
+      await llmSimulator.complete();
+
+      await waitFor(() => functionResponsePromiseResolve !== undefined);
+
+      functionResponsePromiseResolve!({
+        content: repeat('word ', 10000),
+      });
+
+      await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);
+
+      await llmSimulator.next({ content: 'Looks like this was truncated' });
+
+      await llmSimulator.complete();
+
+      await finished(stream);
+    });
+    it('truncates the message', () => {
+      const body = JSON.parse(
+        (actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
+      );
+
+      const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);
+
+      expect(parsed).toEqual({
+        message: 'Function response exceeded the maximum length allowed and was truncated',
+        truncated: expect.any(String),
+      });
+
+      expect(parsed.truncated.includes('word ')).toBe(true);
+    });
+  });
 });
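The tests above drive the client through `createLlmSimulator()`, whose implementation sits outside this diff (only its hunk header is visible). As a rough, hypothetical sketch of what such a simulator could look like, assuming OpenAI-style server-sent-event chunks written to a `PassThrough` stream; the chunk shape and the `data: [DONE]` terminator are assumptions, not the actual helper:

```ts
import { PassThrough } from 'stream';

// Hypothetical stand-in for the createLlmSimulator() helper used by the tests above.
function createLlmSimulator() {
  const stream = new PassThrough();

  return {
    stream,
    // Emit one OpenAI-style streaming chunk carrying a message delta.
    next: (msg: { content?: string; function_call?: { name: string } }) =>
      new Promise<void>((resolve, reject) => {
        const chunk = { object: 'chat.completion.chunk', choices: [{ delta: msg }] };
        stream.write(`data: ${JSON.stringify(chunk)}\n\n`, (err) => (err ? reject(err) : resolve()));
      }),
    // Mark the completion as done and end the stream.
    complete: () =>
      new Promise<void>((resolve) => {
        stream.write('data: [DONE]\n\n');
        stream.end(resolve);
      }),
  };
}
```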
@@ -10,12 +10,13 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
 import type { ElasticsearchClient } from '@kbn/core/server';
 import type { Logger } from '@kbn/logging';
 import type { PublicMethodsOf } from '@kbn/utility-types';
-import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
 import type {
   ChatCompletionRequestMessage,
   CreateChatCompletionRequest,
   CreateChatCompletionResponse,
 } from 'openai';
+import { decode, encode } from 'gpt-tokenizer';
+import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
 import { isObservable, lastValueFrom } from 'rxjs';
 import { PassThrough, Readable } from 'stream';
 import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
       });
     }
 
+    let numFunctionsCalled: number = 0;
+
+    const MAX_FUNCTION_CALLS = 3;
+    const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
+
     const next = async (nextMessages: Message[]): Promise<void> => {
       const lastMessage = last(nextMessages);
 
@@ -222,9 +228,12 @@ export class ObservabilityAIAssistantClient {
             connectorId,
             stream: true,
             signal,
-            functions: functionClient
-              .getFunctions()
-              .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
+            functions:
+              numFunctionsCalled >= MAX_FUNCTION_CALLS
+                ? []
+                : functionClient
+                    .getFunctions()
+                    .map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
           })
         ).pipe(processOpenAiStream()),
       });
@@ -232,22 +241,52 @@ export class ObservabilityAIAssistantClient {
       }
 
       if (isAssistantMessageWithFunctionRequest) {
-        const functionResponse = await functionClient
-          .executeFunction({
-            connectorId,
-            name: lastMessage.message.function_call!.name,
-            messages: nextMessages,
-            args: lastMessage.message.function_call!.arguments,
-            signal,
-          })
-          .catch((error): FunctionResponse => {
-            return {
-              content: {
-                message: error.toString(),
-                error,
-              },
-            };
-          });
+        const functionResponse =
+          numFunctionsCalled >= MAX_FUNCTION_CALLS
+            ? {
+                content: {
+                  error: {},
+                  message: 'Function limit exceeded, ask the user what to do next',
+                },
+              }
+            : await functionClient
+                .executeFunction({
+                  connectorId,
+                  name: lastMessage.message.function_call!.name,
+                  messages: nextMessages,
+                  args: lastMessage.message.function_call!.arguments,
+                  signal,
+                })
+                .then((response) => {
+                  if (isObservable(response)) {
+                    return response;
+                  }
+
+                  const encoded = encode(JSON.stringify(response.content || {}));
+
+                  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
+                    return response;
+                  }
+
+                  return {
+                    data: response.data,
+                    content: {
+                      message:
+                        'Function response exceeded the maximum length allowed and was truncated',
+                      truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
+                    },
+                  };
+                })
+                .catch((error): FunctionResponse => {
+                  return {
+                    content: {
+                      message: error.toString(),
+                      error,
+                    },
+                  };
+                });
+
+        numFunctionsCalled++;
 
         if (signal.aborted) {
           return;
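The second guardrail, token-based truncation of long function responses, in isolation: a sketch that reuses the `gpt-tokenizer` `encode`/`decode` calls and the lodash `take` shown in the diff above. `truncateFunctionResponse` is a hypothetical helper name, not part of the Kibana code.

```ts
import { decode, encode } from 'gpt-tokenizer';
import { repeat, take } from 'lodash';

const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

// Sketch only: the limit, the encode/truncate/decode round trip, and the
// truncation message are taken from the change above.
function truncateFunctionResponse(content: unknown): unknown {
  const encoded = encode(JSON.stringify(content ?? {}));

  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
    return content;
  }

  return {
    message: 'Function response exceeded the maximum length allowed and was truncated',
    // Keep only the first 4000 tokens and decode them back into text.
    truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
  };
}

// Example: an oversized payload comes back as the truncation message plus a shortened string.
console.log(truncateFunctionResponse({ rows: repeat('word ', 10000) }));
```

This mirrors what the "truncates the message" test asserts: a response built from `repeat('word ', 10000)` is replaced by the truncation message plus a shortened `truncated` string.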