[8.12] [Obs AI Assistant] Add guardrails (#174060) (#174117)

# Backport

This will backport the following commits from `main` to `8.12`:
- [[Obs AI Assistant] Add guardrails
(#174060)](https://github.com/elastic/kibana/pull/174060)
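
For reference, the change being backported adds two guardrails to the Observability AI Assistant's completion loop: it caps function looping at 3 function calls per completion request (past the cap, the LLM is no longer offered any functions and further function requests receive an error payload), and it truncates function responses longer than 4000 tokens. The sketch below is a simplified illustration of that logic, not the exact Kibana implementation; the helper names (`functionsForRequest`, `guardedExecuteFunction`, `truncateIfTooLong`) are hypothetical, while the constants, messages, and `gpt-tokenizer` usage mirror the diff shown further down.

```ts
// Minimal sketch of the guardrails, assuming a simplified FunctionResponse shape.
// Helper names are hypothetical; constants and messages follow the backported diff.
import { encode, decode } from 'gpt-tokenizer';

const MAX_FUNCTION_CALLS = 3;
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

interface FunctionDefinition {
  name: string;
  description: string;
  parameters: object;
}

interface FunctionResponse {
  content: Record<string, unknown>;
}

let numFunctionsCalled = 0;

// Guardrail 1a: once the limit is reached, stop offering functions to the LLM.
function functionsForRequest(all: FunctionDefinition[]): FunctionDefinition[] {
  return numFunctionsCalled >= MAX_FUNCTION_CALLS ? [] : all;
}

// Guardrail 1b: refuse further function executions past the limit, and count calls.
async function guardedExecuteFunction(
  execute: () => Promise<FunctionResponse>
): Promise<FunctionResponse> {
  if (numFunctionsCalled >= MAX_FUNCTION_CALLS) {
    return {
      content: { error: {}, message: 'Function limit exceeded, ask the user what to do next' },
    };
  }
  numFunctionsCalled++;
  return truncateIfTooLong(await execute());
}

// Guardrail 2: truncate oversized function responses to a fixed token budget.
function truncateIfTooLong(response: FunctionResponse): FunctionResponse {
  const encoded = encode(JSON.stringify(response.content ?? {}));
  if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
    return response;
  }
  return {
    content: {
      message: 'Function response exceeded the maximum length allowed and was truncated',
      truncated: decode(encoded.slice(0, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
    },
  };
}
```

In the actual client (second diff below) these checks are inlined in the completion recursion rather than split into helpers.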

<!--- Backport version: 9.4.3 -->

### Questions?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Dario
Gieselaar","email":"dario.gieselaar@elastic.co"},"sourceCommit":{"committedDate":"2024-01-02T18:35:49Z","message":"[Obs
AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against
function looping (max 3 calls in a completion\r\nrequest) and long
function responses (max 4000
tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c","branchLabelMapping":{"^v8.13.0$":"main","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v8.12.0","v8.12.1","v8.13.0"],"title":"[Obs
AI Assistant] Add
guardrails","number":174060,"url":"https://github.com/elastic/kibana/pull/174060","mergeCommit":{"message":"[Obs
AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against
function looping (max 3 calls in a completion\r\nrequest) and long
function responses (max 4000
tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c"}},"sourceBranch":"main","suggestedTargetBranches":["8.12"],"targetPullRequestStates":[{"branch":"8.12","label":"v8.12.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v8.13.0","branchLabelMappingKey":"^v8.13.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/174060","number":174060,"mergeCommit":{"message":"[Obs
AI Assistant] Add guardrails (#174060)\n\nAdd guardrails against
function looping (max 3 calls in a completion\r\nrequest) and long
function responses (max 4000
tokens).\r\n\r\n---------\r\n\r\nCo-authored-by: Kibana Machine
<42973632+kibanamachine@users.noreply.github.com>","sha":"725a21727af57b11bdd709966d20f71f9c9be16c"}}]}]
BACKPORT-->

Co-authored-by: Dario Gieselaar <dario.gieselaar@elastic.co>
Authored by Kibana Machine on 2024-01-02 14:48:59 -05:00, committed via GitHub
Parent commit: fdd0dc4ce5
Commit: a38fc92c30
2 changed files with 281 additions and 36 deletions

Changed file: Observability AI Assistant client unit tests

@@ -7,9 +7,11 @@
import type { ActionsClient } from '@kbn/actions-plugin/server/actions_client';
import type { ElasticsearchClient, Logger } from '@kbn/core/server';
import type { DeeplyMockedKeys } from '@kbn/utility-types-jest';
import { merge } from 'lodash';
import { waitFor } from '@testing-library/react';
import { last, merge, repeat } from 'lodash';
import { ChatCompletionResponseMessage } from 'openai';
import { Subject } from 'rxjs';
import { PassThrough, type Readable } from 'stream';
import { EventEmitter, PassThrough, type Readable } from 'stream';
import { finished } from 'stream/promises';
import { ObservabilityAIAssistantClient } from '.';
import { createResourceNamesMap } from '..';
@@ -70,7 +72,7 @@ function createLlmSimulator() {
};
}
describe('Observability AI Assistant service', () => {
describe('Observability AI Assistant client', () => {
let client: ObservabilityAIAssistantClient;
const actionsClientMock: DeeplyMockedKeys<ActionsClient> = {
@@ -84,14 +86,8 @@ describe('Observability AI Assistant service', () => {
} as any;
const currentUserEsClientMock: DeeplyMockedKeys<ElasticsearchClient> = {
search: jest.fn().mockResolvedValue({
hits: {
hits: [],
},
}),
fieldCaps: jest.fn().mockResolvedValue({
fields: [],
}),
search: jest.fn(),
fieldCaps: jest.fn(),
} as any;
const knowledgeBaseServiceMock: DeeplyMockedKeys<KnowledgeBaseService> = {
@@ -107,16 +103,29 @@ describe('Observability AI Assistant service', () => {
const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
executeFunction: jest.fn(),
getFunctions: jest.fn().mockReturnValue([]),
hasFunction: jest.fn().mockImplementation((name) => {
return name !== 'recall';
}),
getFunctions: jest.fn(),
hasFunction: jest.fn(),
} as any;
let llmSimulator: LlmSimulator;
function createClient() {
jest.clearAllMocks();
jest.resetAllMocks();
functionClientMock.getFunctions.mockReturnValue([]);
functionClientMock.hasFunction.mockImplementation((name) => {
return name !== 'recall';
});
currentUserEsClientMock.search.mockResolvedValue({
hits: {
hits: [],
},
} as any);
currentUserEsClientMock.fieldCaps.mockResolvedValue({
fields: [],
} as any);
return new ObservabilityAIAssistantClient({
actionsClient: actionsClientMock,
@@ -158,6 +167,10 @@ describe('Observability AI Assistant service', () => {
);
}
beforeEach(() => {
jest.clearAllMocks();
});
describe('when completing a conversation without an initial conversation id', () => {
let stream: Readable;
@@ -1148,4 +1161,197 @@ describe('Observability AI Assistant service', () => {
});
});
});
describe('when the LLM keeps on calling a function and the limit has been exceeded', () => {
let stream: Readable;
let dataHandler: jest.Mock;
beforeEach(async () => {
client = createClient();
const onLlmCall = new EventEmitter();
function waitForNextLlmCall() {
return new Promise<void>((resolve) => onLlmCall.addListener('next', resolve));
}
actionsClientMock.execute.mockImplementation(async () => {
llmSimulator = createLlmSimulator();
onLlmCall.emit('next');
return {
actionId: '',
status: 'ok',
data: llmSimulator.stream,
};
});
functionClientMock.getFunctions.mockImplementation(() => [
{
definition: {
name: 'get_top_alerts',
contexts: ['core'],
description: '',
parameters: {},
},
respond: async () => {
return { content: 'Call this function again' };
},
},
]);
functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
functionClientMock.executeFunction.mockImplementation(async () => ({
content: 'Call this function again',
}));
stream = await client.complete({
connectorId: 'foo',
messages: [system('This is a system message'), user('How many alerts do I have?')],
functionClient: functionClientMock,
signal: new AbortController().signal,
title: 'My predefined title',
persist: true,
});
dataHandler = jest.fn();
stream.on('data', dataHandler);
async function requestAlertsFunctionCall() {
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);
let nextLlmCallPromise: Promise<void>;
if (body.functions?.length) {
nextLlmCallPromise = waitForNextLlmCall();
await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
} else {
nextLlmCallPromise = Promise.resolve();
await llmSimulator.next({ content: 'Looks like we are done here' });
}
await llmSimulator.complete();
await nextLlmCallPromise;
}
await requestAlertsFunctionCall();
await requestAlertsFunctionCall();
await requestAlertsFunctionCall();
await requestAlertsFunctionCall();
await finished(stream);
});
it('executed the function no more than three times', () => {
expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(3);
});
it('does not give the LLM the choice to call a function anymore', () => {
const firstBody = JSON.parse(
(actionsClientMock.execute.mock.calls[0][0].params as any).subActionParams.body
);
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);
expect(firstBody.functions.length).toBe(1);
expect(body.functions).toBeUndefined();
});
});
describe('when the function response exceeds the max no of tokens for one', () => {
let stream: Readable;
let dataHandler: jest.Mock;
beforeEach(async () => {
client = createClient();
let functionResponsePromiseResolve: Function | undefined;
actionsClientMock.execute.mockImplementation(async () => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
status: 'ok',
data: llmSimulator.stream,
};
});
functionClientMock.getFunctions.mockImplementation(() => [
{
definition: {
name: 'get_top_alerts',
contexts: ['core'],
description: '',
parameters: {},
},
respond: async () => {
return { content: '' };
},
},
]);
functionClientMock.hasFunction.mockImplementation((name) => name === 'get_top_alerts');
functionClientMock.executeFunction.mockImplementation(() => {
return new Promise((resolve) => {
functionResponsePromiseResolve = resolve;
});
});
stream = await client.complete({
connectorId: 'foo',
messages: [system('This is a system message'), user('How many alerts do I have?')],
functionClient: functionClientMock,
signal: new AbortController().signal,
title: 'My predefined title',
persist: true,
});
dataHandler = jest.fn();
stream.on('data', dataHandler);
await llmSimulator.next({ function_call: { name: 'get_top_alerts' } });
await llmSimulator.complete();
await waitFor(() => functionResponsePromiseResolve !== undefined);
functionResponsePromiseResolve!({
content: repeat('word ', 10000),
});
await waitFor(() => actionsClientMock.execute.mock.calls.length > 1);
await llmSimulator.next({ content: 'Looks like this was truncated' });
await llmSimulator.complete();
await finished(stream);
});
it('truncates the message', () => {
const body = JSON.parse(
(actionsClientMock.execute.mock.lastCall![0].params as any).subActionParams.body
);
const parsed = JSON.parse(last(body.messages as ChatCompletionResponseMessage[])!.content!);
expect(parsed).toEqual({
message: 'Function response exceeded the maximum length allowed and was truncated',
truncated: expect.any(String),
});
expect(parsed.truncated.includes('word ')).toBe(true);
});
});
});

Changed file: Observability AI Assistant client implementation

@@ -10,12 +10,13 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import { compact, isEmpty, last, merge, omit, pick } from 'lodash';
import type {
ChatCompletionRequestMessage,
CreateChatCompletionRequest,
CreateChatCompletionResponse,
} from 'openai';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, omit, pick, take } from 'lodash';
import { isObservable, lastValueFrom } from 'rxjs';
import { PassThrough, Readable } from 'stream';
import { v4 } from 'uuid';
@@ -176,6 +177,11 @@ export class ObservabilityAIAssistantClient {
});
}
let numFunctionsCalled: number = 0;
const MAX_FUNCTION_CALLS = 3;
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
const next = async (nextMessages: Message[]): Promise<void> => {
const lastMessage = last(nextMessages);
@@ -222,9 +228,12 @@
connectorId,
stream: true,
signal,
functions: functionClient
.getFunctions()
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
functions:
numFunctionsCalled >= MAX_FUNCTION_CALLS
? []
: functionClient
.getFunctions()
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
})
).pipe(processOpenAiStream()),
});
@@ -232,22 +241,52 @@
}
if (isAssistantMessageWithFunctionRequest) {
const functionResponse = await functionClient
.executeFunction({
connectorId,
name: lastMessage.message.function_call!.name,
messages: nextMessages,
args: lastMessage.message.function_call!.arguments,
signal,
})
.catch((error): FunctionResponse => {
return {
content: {
message: error.toString(),
error,
},
};
});
const functionResponse =
numFunctionsCalled >= MAX_FUNCTION_CALLS
? {
content: {
error: {},
message: 'Function limit exceeded, ask the user what to do next',
},
}
: await functionClient
.executeFunction({
connectorId,
name: lastMessage.message.function_call!.name,
messages: nextMessages,
args: lastMessage.message.function_call!.arguments,
signal,
})
.then((response) => {
if (isObservable(response)) {
return response;
}
const encoded = encode(JSON.stringify(response.content || {}));
if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
return response;
}
return {
data: response.data,
content: {
message:
'Function response exceeded the maximum length allowed and was truncated',
truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
},
};
})
.catch((error): FunctionResponse => {
return {
content: {
message: error.toString(),
error,
},
};
});
numFunctionsCalled++;
if (signal.aborted) {
return;