[Obs AI Assistant] Add API test for get_alerts_dataset_info tool (#212858)

Follow-up to: https://github.com/elastic/kibana/pull/212077

This PR includes an API test that covers `get_alerts_dataset_info` and
would have caught the bug fixed in
https://github.com/elastic/kibana/pull/212077.

It also contains the following bug fixes:

- Fix system message in `select_relevant_fields`
- Change prompt in `select_relevant_fields` so that the LLM consistently
uses the right format when responding.
This commit is contained in:
Søren Louv-Jansen 2025-03-05 09:09:22 +01:00 committed by GitHub
parent 752af8338e
commit 0fb83efd82
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 728 additions and 159 deletions

View file

@ -13,6 +13,12 @@ import { MessageRole, ShortIdTable, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';
// Tool name used both when registering the function schema and when forcing
// the LLM to call it via `functionCall`/`tool_choice`.
const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields';
// System message for the field-selection sub-request. Exported so API tests
// can assert the exact prompt that is sent to the LLM.
// Note: fixed missing comma ("field names, only") — without it the sentence
// reads as "any field names only the corresponding ids", which is ambiguous
// instruction text for the model.
export const GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE = `You are a helpful assistant for Elastic Observability.
Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list.
The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id".
You must not output any field names, only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`;
export async function getRelevantFieldNames({
index,
start,
@ -100,11 +106,7 @@ export async function getRelevantFieldNames({
await chat('get_relevant_dataset_names', {
signal,
stream: true,
systemMessage: `You are a helpful assistant for Elastic Observability.
Your task is to create a list of field names that are relevant
to the conversation, using ONLY the list of fields and
types provided in the last user message. DO NOT UNDER ANY
CIRCUMSTANCES include fields not mentioned in this list.`,
systemMessage: GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE,
messages: [
// remove the last function request
...messages.slice(0, -1),
@ -112,7 +114,7 @@ export async function getRelevantFieldNames({
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: `This is the list:
content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:
${fieldsInChunk
.map((field) => JSON.stringify({ field, id: shortIdTable.take(field) }))
@ -122,8 +124,12 @@ export async function getRelevantFieldNames({
],
functions: [
{
name: 'select_relevant_fields',
description: 'The IDs of the fields you consider relevant to the conversation',
name: SELECT_RELEVANT_FIELDS_NAME,
description: `Return only the field IDs (from the provided list) that you consider relevant to the conversation. Do not use any of the field names. Your response must be in the exact JSON format:
{
"fieldIds": ["id1", "id2", "id3"]
}
Only include IDs from the list provided in the user message.`,
parameters: {
type: 'object',
properties: {
@ -138,7 +144,7 @@ export async function getRelevantFieldNames({
} as const,
},
],
functionCall: 'select_relevant_fields',
functionCall: SELECT_RELEVANT_FIELDS_NAME,
})
).pipe(concatenateChatCompletionChunks());

View file

@ -39,52 +39,46 @@ export const registerFunctions: RegistrationCallback = async ({
};
const isServerless = !!resources.plugins.serverless;
if (scopes.includes('observability')) {
functions.registerInstruction(`You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities.
It's very important to not assume what the user is meaning. Ask them for clarification if needed.
const isObservabilityDeployment = scopes.includes('observability');
const isGenericDeployment = scopes.length === 0 || (scopes.length === 1 && scopes[0] === 'all');
If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation.
if (isObservabilityDeployment || isGenericDeployment) {
functions.registerInstruction(`
${
isObservabilityDeployment
? `You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities.`
: `You are a helpful assistant for Elasticsearch. Your goal is to help Elasticsearch users accomplish tasks using Kibana and Elasticsearch. You can help them construct queries, index data, search data, use Elasticsearch APIs, generate sample data, visualise and analyze data.`
}
It's very important to not assume what the user means. Ask them for clarification if needed.
If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation.
In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\
/\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important!
You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response.
${
isObservabilityDeployment
? 'Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.'
: ''
}
If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results
returned to you, before executing the same tool or another tool again if needed.
In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\
/\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important!
You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response.
Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.
If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results
returned to you, before executing the same tool or another tool again if needed.
DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (\`service.name == "foo"\`) with "kqlFilter" (\`service.name:"foo"\`).
The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${
isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants`
}.
If the user asks how to change the language, reply in the same language the user asked in.`);
}
if (scopes.length === 0 || (scopes.length === 1 && scopes[0] === 'all')) {
functions.registerInstruction(
`You are a helpful assistant for Elasticsearch. Your goal is to help Elasticsearch users accomplish tasks using Kibana and Elasticsearch. You can help them construct queries, index data, search data, use Elasticsearch APIs, generate sample data, visualise and analyze data.
It's very important to not assume what the user means. Ask them for clarification if needed.
If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation.
In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: ':()\\\
/\". Always put a field value in double quotes. Best: service.name:\"opbeans-go\". Wrong: service.name:opbeans-go. This is very important!
You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response.
If you want to call a function or tool, only call it a single time per message. Wait until the function has been executed and its results
returned to you, before executing the same tool or another tool again if needed.
The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${
isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants`
}.
If the user asks how to change the language, reply in the same language the user asked in.`
);
${
isObservabilityDeployment
? 'DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (`service.name == "foo"`) with "kqlFilter" (`service.name:"foo"`).'
: ''
}
The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the ${
isServerless ? `Project settings.` : `Stack Management app under the option AI Assistants`
}.
If the user asks how to change the language, reply in the same language the user asked in.`);
}
const { ready: isKnowledgeBaseReady } = await client.getKnowledgeBaseStatus();

View file

@ -274,8 +274,8 @@ export class ObservabilityAIAssistantClient {
chat: (name, chatParams) => {
// inject a chat function with predefined parameters
return this.chat(name, {
...chatParams,
systemMessage,
...chatParams,
signal,
simulateFunctionCalling,
connectorId,

View file

@ -112,9 +112,10 @@ export function registerAlertsFunction({
signal,
chat: (
operationName,
{ messages: nextMessages, functionCall, functions: nextFunctions }
{ messages: nextMessages, functionCall, functions: nextFunctions, systemMessage }
) => {
return chat(operationName, {
systemMessage,
messages: nextMessages,
functionCall,
functions: nextFunctions,

View file

@ -18,7 +18,6 @@ import { createFunctionResponseMessage } from '@kbn/observability-ai-assistant-p
import { convertMessagesForInference } from '@kbn/observability-ai-assistant-plugin/common/convert_messages_for_inference';
import { map } from 'rxjs';
import { v4 } from 'uuid';
import { RegisterInstructionCallback } from '@kbn/observability-ai-assistant-plugin/server/service/types';
import type { FunctionRegistrationParameters } from '..';
import { runAndValidateEsqlQuery } from './validate_esql_query';
@ -30,9 +29,12 @@ export function registerQueryFunction({
resources,
pluginsStart,
}: FunctionRegistrationParameters) {
const instruction: RegisterInstructionCallback = ({ availableFunctionNames }) =>
availableFunctionNames.includes(QUERY_FUNCTION_NAME)
? `You MUST use the "${QUERY_FUNCTION_NAME}" function when the user wants to:
functions.registerInstruction(({ availableFunctionNames }) => {
if (!availableFunctionNames.includes(QUERY_FUNCTION_NAME)) {
return;
}
return `You MUST use the "${QUERY_FUNCTION_NAME}" function when the user wants to:
- visualize data
- run any arbitrary query
- breakdown or filter ES|QL queries that are displayed on the current page
@ -48,9 +50,8 @@ export function registerQueryFunction({
even if it has been called before.
When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.
If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`
: undefined;
functions.registerInstruction(instruction);
If the "${EXECUTE_QUERY_NAME}" function has been called, summarize these results for the user. The user does not see a visualization in this case.`;
});
functions.registerFunction(
{

View file

@ -85,7 +85,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
expect(status).to.be(200);
});
@ -104,7 +104,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody; // This is the request sent to the LLM
expect(requestData.messages[0].content).to.eql(SYSTEM_MESSAGE);

View file

@ -76,7 +76,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
scopes: ['all'],
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
return String(response.body)
.split('\n')
@ -133,7 +133,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
parsedEvents = decodeEvents(receivedChunks.join(''));
});
@ -243,7 +243,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody;
expect(requestData.messages[0].role).to.eql('system');
@ -420,7 +420,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(createResponse.status).to.be(200);
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
conversationCreatedEvent = getConversationCreatedEvent(createResponse.body);
@ -463,7 +463,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(updatedResponse.status).to.be(200);
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
});
after(async () => {

View file

@ -46,7 +46,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
alertsEvents = getMessageAddedEvents(alertsResponseBody);
});

View file

@ -65,7 +65,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
events = getMessageAddedEvents(responseBody);
});

View file

@ -0,0 +1,500 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import expect from '@kbn/expect';
import { ApmRuleType } from '@kbn/rule-data-utils';
import { apm, timerange } from '@kbn/apm-synthtrace-client';
import { ApmSynthtraceEsClient } from '@kbn/apm-synthtrace';
import { RoleCredentials } from '@kbn/ftr-common-functional-services';
import { last } from 'lodash';
import { GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper';
import {
LlmProxy,
createLlmProxy,
} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { getMessageAddedEvents } from './helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper';
const USER_MESSAGE = 'How many alerts do I have for the past 10 days?';
/**
 * API test for the `get_alerts_dataset_info` tool.
 *
 * Drives a full `chat/complete` round-trip through an LLM proxy: the proxy
 * scripts four LLM responses (two tool calls, one forced field selection, one
 * final text reply), and the suite then asserts on (a) the request bodies the
 * assistant sent to the LLM and (b) the messageAdded events streamed back.
 */
export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
const log = getService('log');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const alertingApi = getService('alertingApi');
const samlAuth = getService('samlAuth');
describe('function: get_alerts_dataset_info', function () {
// Fails on MKI: https://github.com/elastic/kibana/issues/205581
this.tags(['failsOnMKI']);
let llmProxy: LlmProxy;
let connectorId: string;
let messageAddedEvents: MessageAddEvent[];
let apmSynthtraceEsClient: ApmSynthtraceEsClient;
let roleAuthc: RoleCredentials;
let createdRuleId: string;
// Field names the mocked LLM "selects" below; captured inside the
// interceptor so later assertions can compare them with the actual
// `get_alerts_dataset_info` function response.
let expectedRelevantFieldNames: string[];
let primarySystemMessage: string;
before(async () => {
// Seed synthetic APM data and create an error-count rule so that at
// least one alert document exists for the `alerts` function to find.
({ apmSynthtraceEsClient } = await createSyntheticApmData(getService));
({ roleAuthc, createdRuleId } = await createApmErrorCountRule(getService));
llmProxy = await createLlmProxy(log);
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
port: llmProxy.getPort(),
});
// 1st LLM response: a `get_alerts_dataset_info` tool call.
void llmProxy.interceptWithFunctionRequest({
name: 'get_alerts_dataset_info',
arguments: () => JSON.stringify({ start: 'now-10d', end: 'now' }),
when: () => true,
});
// 2nd LLM response: matched via `tool_choice` (the assistant forces
// `select_relevant_fields`); answer with the ids of the first five
// candidate fields parsed out of the generated user message.
void llmProxy.interceptWithFunctionRequest({
name: 'select_relevant_fields',
// @ts-expect-error
when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields',
arguments: (requestBody) => {
const userMessage = last(requestBody.messages);
const topFields = (userMessage?.content as string)
.slice(204) // remove the prefix message and only get the JSON
.trim()
.split('\n')
.map((line) => JSON.parse(line))
.slice(0, 5);
expectedRelevantFieldNames = topFields.map(({ field }) => field);
const fieldIds = topFields.map(({ id }) => id);
return JSON.stringify({ fieldIds });
},
});
// 3rd LLM response: an `alerts` tool call.
void llmProxy.interceptWithFunctionRequest({
name: 'alerts',
arguments: () => JSON.stringify({ start: 'now-10d', end: 'now' }),
when: () => true,
});
// 4th LLM response: the final plain-text assistant reply.
void llmProxy.interceptConversation(
`You have active alerts for the past 10 days. Back to work!`
);
const { status, body } = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
params: {
body: {
messages: [
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: USER_MESSAGE,
},
},
],
connectorId,
persist: false,
screenContexts: [],
scopes: ['observability' as const],
},
},
});
expect(status).to.be(200);
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
messageAddedEvents = getMessageAddedEvents(body);
// Retrieve the primary system message so the suite can compare it with
// the system message that was actually sent to the LLM.
const {
body: { systemMessage },
} = await observabilityAIAssistantAPIClient.editor({
endpoint: 'GET /internal/observability_ai_assistant/functions',
params: {
query: {
scopes: ['observability'],
},
},
});
primarySystemMessage = systemMessage;
});
after(async () => {
// Tear down everything created in `before`: proxy, connector,
// synthtrace data, alerts/rule, and the m2m API key.
llmProxy.close();
await observabilityAIAssistantAPIClient.deleteActionConnector({
actionId: connectorId,
});
await apmSynthtraceEsClient.clean();
await alertingApi.cleanUpAlerts({
roleAuthc,
ruleId: createdRuleId,
alertIndexName: APM_ALERTS_INDEX,
consumer: 'apm',
});
await samlAuth.invalidateM2mApiKeyWithRoleScope(roleAuthc);
});
describe('LLM requests', () => {
let firstRequestBody: ChatCompletionStreamParams;
let secondRequestBody: ChatCompletionStreamParams;
let thirdRequestBody: ChatCompletionStreamParams;
let fourthRequestBody: ChatCompletionStreamParams;
before(async () => {
firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
thirdRequestBody = llmProxy.interceptedRequests[2].requestBody;
fourthRequestBody = llmProxy.interceptedRequests[3].requestBody;
});
it('makes 4 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(4);
});
describe('every request to the LLM', () => {
it('contains a system message', () => {
const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every(
({ requestBody }) => {
const firstMessage = requestBody.messages[0];
return (
firstMessage.role === 'system' &&
(firstMessage.content as string).includes('You are a helpful assistant')
);
}
);
expect(everyRequestHasSystemMessage).to.be(true);
});
it('contains the original user message', () => {
const everyRequestHasUserMessage = llmProxy.interceptedRequests.every(({ requestBody }) =>
requestBody.messages.some(
(message) => message.role === 'user' && (message.content as string) === USER_MESSAGE
)
);
expect(everyRequestHasUserMessage).to.be(true);
});
it('contains the context function request and context function response', () => {
const everyRequestHasContextFunction = llmProxy.interceptedRequests.every(
({ requestBody }) => {
const hasContextFunctionRequest = requestBody.messages.some(
(message) =>
message.role === 'assistant' &&
message.tool_calls?.[0]?.function?.name === 'context'
);
const hasContextFunctionResponse = requestBody.messages.some(
(message) =>
message.role === 'tool' &&
(message.content as string).includes('screen_description') &&
(message.content as string).includes('learnings')
);
return hasContextFunctionRequest && hasContextFunctionResponse;
}
);
expect(everyRequestHasContextFunction).to.be(true);
});
});
// Request 1: the initial chat/complete turn, where the LLM decides on
// its own (tool_choice=auto) to call `get_alerts_dataset_info`.
describe('The first request', () => {
it('contains the correct number of messages', () => {
expect(firstRequestBody.messages.length).to.be(4);
});
it('contains the `get_alerts_dataset_info` tool', () => {
const hasTool = firstRequestBody.tools?.some(
(tool) => tool.function.name === 'get_alerts_dataset_info'
);
expect(hasTool).to.be(true);
});
it('leaves the function calling decision to the LLM via tool_choice=auto', () => {
expect(firstRequestBody.tool_choice).to.be('auto');
});
describe('The system message', () => {
it('has the primary system message', () => {
// instruction order can vary, hence the sorted comparison
expect(sortSystemMessage(firstRequestBody.messages[0].content as string)).to.eql(
sortSystemMessage(primarySystemMessage)
);
});
it('has a different system message from request 2', () => {
expect(firstRequestBody.messages[0]).not.to.eql(secondRequestBody.messages[0]);
});
it('has the same system message as request 3', () => {
expect(firstRequestBody.messages[0]).to.eql(thirdRequestBody.messages[0]);
});
it('has the same system message as request 4', () => {
expect(firstRequestBody.messages[0]).to.eql(fourthRequestBody.messages[0]);
});
});
});
// Request 2: the internal field-selection sub-request issued by
// `get_alerts_dataset_info`, which forces `select_relevant_fields`.
describe('The second request', () => {
it('contains the correct number of messages', () => {
expect(secondRequestBody.messages.length).to.be(5);
});
it('contains a system generated user message with a list of field candidates', () => {
const hasList = secondRequestBody.messages.some(
(message) =>
message.role === 'user' &&
(message.content as string).includes('Below is a list of fields.') &&
(message.content as string).includes('@timestamp')
);
expect(hasList).to.be(true);
});
it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => {
const hasToolChoice =
// @ts-expect-error
secondRequestBody.tool_choice?.function?.name === 'select_relevant_fields';
expect(hasToolChoice).to.be(true);
});
it('has a custom, function-specific system message', () => {
expect(secondRequestBody.messages[0].content).to.be(
GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE
);
});
});
// Request 3: back on the main conversation, now including the
// `get_alerts_dataset_info` request/response pair.
describe('The third request', () => {
it('contains the correct number of messages', () => {
expect(thirdRequestBody.messages.length).to.be(6);
});
it('contains the `get_alerts_dataset_info` request', () => {
const hasFunctionRequest = thirdRequestBody.messages.some(
(message) =>
message.role === 'assistant' &&
message.tool_calls?.[0]?.function?.name === 'get_alerts_dataset_info'
);
expect(hasFunctionRequest).to.be(true);
});
it('contains the `get_alerts_dataset_info` response', () => {
const functionResponse = last(thirdRequestBody.messages);
const parsedContent = JSON.parse(functionResponse?.content as string) as {
fields: string[];
};
// response fields are "name:type" strings; strip the type to compare
// against the names the mocked LLM selected
const fieldNamesWithType = parsedContent.fields;
const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]);
expect(fieldNamesWithoutType).to.eql(expectedRelevantFieldNames);
expect(fieldNamesWithType).to.eql([
'@timestamp:date',
'_id:_id',
'_ignored:string',
'_index:_index',
'_score:number',
]);
});
it('emits a messageAdded event with the `get_alerts_dataset_info` function response', async () => {
const messageWithDatasetInfo = messageAddedEvents.find(
({ message }) =>
message.message.role === MessageRole.User &&
message.message.name === 'get_alerts_dataset_info'
);
const parsedContent = JSON.parse(messageWithDatasetInfo?.message.message.content!) as {
fields: string[];
};
expect(parsedContent.fields).to.eql([
'@timestamp:date',
'_id:_id',
'_ignored:string',
'_index:_index',
'_score:number',
]);
});
it('contains the `alerts` tool', () => {
const hasTool = thirdRequestBody.tools?.some((tool) => tool.function.name === 'alerts');
expect(hasTool).to.be(true);
});
});
// Request 4: includes the `alerts` request/response pair; the LLM then
// replies with the final text answer.
describe('The fourth request', () => {
it('contains the correct number of messages', () => {
expect(fourthRequestBody.messages.length).to.be(8);
});
it('contains the `alerts` request', () => {
const hasFunctionRequest = fourthRequestBody.messages.some(
(message) =>
message.role === 'assistant' && message.tool_calls?.[0]?.function?.name === 'alerts'
);
expect(hasFunctionRequest).to.be(true);
});
it('contains the `alerts` response', () => {
const functionResponseMessage = last(fourthRequestBody.messages);
const parsedContent = JSON.parse(functionResponseMessage?.content as string);
expect(Object.keys(parsedContent)).to.eql(['total', 'alerts']);
});
it('emits a messageAdded event with the `alert` function response', async () => {
const messageWithAlerts = messageAddedEvents.find(
({ message }) =>
message.message.role === MessageRole.User && message.message.name === 'alerts'
);
const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as {
total: number;
alerts: any[];
};
expect(parsedContent.total).to.be(1);
expect(parsedContent.alerts.length).to.be(1);
});
});
});
describe('messageAdded events', () => {
it('emits 7 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(7);
});
it('emits messageAdded events in the correct order', async () => {
const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => {
const { role, name, function_call: functionCall } = message.message;
if (functionCall) {
return { function_call: functionCall, role };
}
return { name, role };
});
expect(formattedMessageAddedEvents).to.eql([
{
role: 'assistant',
function_call: { name: 'context', trigger: 'assistant' },
},
{ name: 'context', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'get_alerts_dataset_info',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'get_alerts_dataset_info', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'alerts',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'alerts', role: 'user' },
{
role: 'assistant',
function_call: { name: '', arguments: '', trigger: 'assistant' },
},
]);
});
});
});
}
/**
 * Creates an APM error-count rule as the `editor` user and blocks until the
 * rule has written its first alert document to the APM alerts index.
 *
 * Returns the role credentials (so the caller can clean up the API key), the
 * created rule id, and the alert documents found in the index.
 */
async function createApmErrorCountRule(
getService: DeploymentAgnosticFtrProviderContext['getService']
) {
const alertingApi = getService('alertingApi');
const samlAuth = getService('samlAuth');
const roleAuthc = await samlAuth.createM2mApiKeyWithRoleScope('editor');
const rule = await alertingApi.createRule({
ruleTypeId: ApmRuleType.ErrorCount,
name: 'APM error threshold',
consumer: 'apm',
schedule: { interval: '1m' },
tags: ['apm'],
params: {
environment: 'production',
threshold: 1,
windowSize: 1,
windowUnit: 'h',
},
roleAuthc,
});
const ruleId = rule.id as string;
// Wait for the first alert document so callers can rely on its existence.
const searchResult = await alertingApi.waitForDocumentInIndex<ApmAlertFields>({
indexName: APM_ALERTS_INDEX,
ruleId,
docCountTarget: 1,
});
const alerts = searchResult.hits.hits.map((hit) => hit._source!);
return { roleAuthc, createdRuleId: ruleId, alerts };
}
/**
 * Indexes 15 minutes of synthetic APM data for the `opbeans-node` service:
 * one failing `DELETE /user/:id` transaction per minute, each carrying an
 * error — enough to trigger the error-count rule created alongside it.
 */
async function createSyntheticApmData(
getService: DeploymentAgnosticFtrProviderContext['getService']
) {
const synthtrace = getService('synthtrace');
const apmSynthtraceEsClient = await synthtrace.createApmSynthtraceEsClient();
const serviceInstance = apm
.service({ name: 'opbeans-node', environment: 'production', agentName: 'node' })
.instance('instance');
const events = timerange('now-15m', 'now')
.ratePerMinute(1)
.generator((timestamp) => [
serviceInstance
.transaction({ transactionName: 'DELETE /user/:id' })
.timestamp(timestamp)
.duration(100)
.failure()
.errors(
serviceInstance.error({ message: 'Unable to delete user' }).timestamp(timestamp + 50)
),
]);
await apmSynthtraceEsClient.index(events);
return { apmSynthtraceEsClient };
}
// The assistant assembles the system message from individual instructions
// whose order can vary between requests, so comparisons are done on the
// sorted set of trimmed paragraphs rather than on the raw string.
function sortSystemMessage(message: string) {
const paragraphs = message.split('\n\n');
const trimmed = paragraphs.map((paragraph) => paragraph.trim());
trimmed.sort();
return trimmed;
}

View file

@ -45,6 +45,13 @@ export async function invokeChatCompleteWithFunctionRequest({
params: {
body: {
messages: [
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: 'Hello from user',
},
},
{
'@timestamp': new Date().toISOString(),
message: {

View file

@ -64,7 +64,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
});
after(async () => {

View file

@ -16,6 +16,7 @@ export default function aiAssistantApiIntegrationTests({
loadTestFile(require.resolve('./chat/chat.spec.ts'));
loadTestFile(require.resolve('./complete/complete.spec.ts'));
loadTestFile(require.resolve('./complete/functions/alerts.spec.ts'));
loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts'));
loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts'));
loadTestFile(require.resolve('./complete/functions/summarize.spec.ts'));
loadTestFile(require.resolve('./public_complete/public_complete.spec.ts'));

View file

@ -308,7 +308,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
expect(createResponse.status).to.be(200);
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const conversationCreatedEvent = getConversationCreatedEvent(createResponse.body);
const conversationId = conversationCreatedEvent.conversation.id;
@ -321,7 +321,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const conversation = res.body;
return conversation;
@ -470,7 +470,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const simulator = await simulatorPromise;
const requestData = simulator.requestBody;
expect(requestData.messages[0].content).to.contain(userInstructionText);

View file

@ -14,9 +14,9 @@ import {
MessageAddEvent,
type StreamingChatResponseEvent,
} from '@kbn/observability-ai-assistant-plugin/common/conversation_complete';
import type OpenAI from 'openai';
import { type AdHocInstruction } from '@kbn/observability-ai-assistant-plugin/common/types';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import {
createLlmProxy,
LlmProxy,
@ -72,7 +72,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
const titleSimulator = await titleSimulatorPromise;
const conversationSimulator = await conversationSimulatorPromise;
@ -156,7 +156,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
describe('after adding an instruction', () => {
let body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
let body: ChatCompletionStreamParams;
before(async () => {
const { conversationSimulator } = await addInterceptorsAndCallComplete({

View file

@ -7,26 +7,29 @@
import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port';
import { v4 as uuidv4 } from 'uuid';
import http, { type Server } from 'http';
import { isString, once, pull } from 'lodash';
import OpenAI from 'openai';
import { isString, once, pull, isFunction } from 'lodash';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { createOpenAiChunk } from './create_openai_chunk';
type Request = http.IncomingMessage;
type Response = http.ServerResponse<http.IncomingMessage> & { req: http.IncomingMessage };
type LLMMessage = string[] | ToolMessage | string | undefined;
type RequestHandler = (
request: Request,
response: Response,
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
requestBody: ChatCompletionStreamParams
) => void;
interface RequestInterceptor {
name: string;
when: (body: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming) => boolean;
when: (body: ChatCompletionStreamParams) => boolean;
}
export interface ToolMessage {
@ -34,8 +37,8 @@ export interface ToolMessage {
tool_calls?: ChatCompletionChunkToolCall[];
}
export interface LlmResponseSimulator {
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming;
status: (code: number) => Promise<void>;
requestBody: ChatCompletionStreamParams;
status: (code: number) => void;
next: (msg: string | ToolMessage) => Promise<void>;
error: (error: any) => Promise<void>;
complete: () => Promise<void>;
@ -46,35 +49,47 @@ export interface LlmResponseSimulator {
export class LlmProxy {
server: Server;
interval: NodeJS.Timeout;
interceptors: Array<RequestInterceptor & { handle: RequestHandler }> = [];
interceptedRequests: Array<{
requestBody: ChatCompletionStreamParams;
matchingInterceptorName: string | undefined;
}> = [];
constructor(private readonly port: number, private readonly log: ToolingLog) {
this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 1000);
this.interval = setInterval(() => this.log.debug(`LLM proxy listening on port ${port}`), 5000);
this.server = http
.createServer()
.on('request', async (request, response) => {
this.log.info(`LLM request received`);
const interceptors = this.interceptors.concat();
const requestBody = await getRequestBody(request);
while (interceptors.length) {
const interceptor = interceptors.shift()!;
const matchingInterceptor = this.interceptors.find(({ when }) => when(requestBody));
this.interceptedRequests.push({
requestBody,
matchingInterceptorName: matchingInterceptor?.name,
});
if (matchingInterceptor) {
this.log.info(`Handling interceptor "${matchingInterceptor.name}"`);
matchingInterceptor.handle(request, response, requestBody);
if (interceptor.when(requestBody)) {
pull(this.interceptors, interceptor);
interceptor.handle(request, response, requestBody);
return;
}
this.log.debug(`Removing interceptor "${matchingInterceptor.name}"`);
pull(this.interceptors, matchingInterceptor);
return;
}
const errorMessage = `No interceptors found to handle request: ${request.method} ${request.url}`;
const availableInterceptorNames = this.interceptors.map(({ name }) => name);
this.log.error(
`Available interceptors: ${JSON.stringify(availableInterceptorNames, null, 2)}`
);
this.log.error(
`${errorMessage}. Messages: ${JSON.stringify(requestBody.messages, null, 2)}`
);
response.writeHead(500, { errorMessage, messages: JSON.stringify(requestBody.messages) });
response.writeHead(500, {
'Elastic-Interceptor': 'Interceptor not found',
});
response.write(sseEvent({ errorMessage, availableInterceptorNames }));
response.end();
})
.on('error', (error) => {
@ -88,7 +103,8 @@ export class LlmProxy {
}
clear() {
this.interceptors.length = 0;
this.interceptors = [];
this.interceptedRequests = [];
}
close() {
@ -97,16 +113,18 @@ export class LlmProxy {
this.server.close();
}
waitForAllInterceptorsSettled() {
waitForAllInterceptorsToHaveBeenCalled() {
return pRetry(
async () => {
if (this.interceptors.length === 0) {
return;
}
const unsettledInterceptors = this.interceptors.map((i) => i.name).join(', ');
const unsettledInterceptors = this.interceptors.map((i) => i.name);
this.log.debug(
`Waiting for the following interceptors to be called: ${unsettledInterceptors}`
`Waiting for the following interceptors to be called: ${JSON.stringify(
unsettledInterceptors
)}`
);
if (this.interceptors.length > 0) {
throw new Error(`Interceptors were not called: ${unsettledInterceptors}`);
@ -120,61 +138,71 @@ export class LlmProxy {
}
interceptConversation(
msg: Array<string | ToolMessage> | ToolMessage | string | undefined,
msg: LLMMessage,
{
name = 'default_interceptor_conversation_name',
name,
}: {
name?: string;
} = {}
) {
return this.intercept(
name,
(body) => !isFunctionTitleRequest(body),
`Conversation interceptor: "${name ?? 'Unnamed'}"`,
// @ts-expect-error
(body) => body.tool_choice?.function?.name === undefined,
msg
).completeAfterIntercept();
}
interceptTitle(title: string) {
return this.intercept(
`conversation_title_interceptor_${title.split(' ').join('_')}`,
(body) => isFunctionTitleRequest(body),
{
interceptWithFunctionRequest({
name: name,
arguments: argumentsCallback,
when,
}: {
name: string;
arguments: (body: ChatCompletionStreamParams) => string;
when: RequestInterceptor['when'];
}) {
// @ts-expect-error
return this.intercept(`Function request interceptor: "${name}"`, when, (body) => {
return {
content: '',
tool_calls: [
{
index: 0,
toolCallId: 'id',
function: {
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: JSON.stringify({ title }),
name,
arguments: argumentsCallback(body),
},
index: 0,
id: `call_${uuidv4()}`,
},
],
}
).completeAfterIntercept();
};
}).completeAfterIntercept();
}
intercept<
TResponseChunks extends
| Array<string | ToolMessage>
| ToolMessage
| string
| undefined = undefined
>(
interceptTitle(title: string) {
return this.interceptWithFunctionRequest({
name: TITLE_CONVERSATION_FUNCTION_NAME,
arguments: () => JSON.stringify({ title }),
// @ts-expect-error
when: (body) => body.tool_choice?.function?.name === TITLE_CONVERSATION_FUNCTION_NAME,
});
}
intercept(
name: string,
when: RequestInterceptor['when'],
responseChunks?: TResponseChunks
): TResponseChunks extends undefined
? { waitForIntercept: () => Promise<LlmResponseSimulator> }
: { completeAfterIntercept: () => Promise<LlmResponseSimulator> } {
responseChunks?: LLMMessage | ((body: ChatCompletionStreamParams) => LLMMessage)
): {
waitForIntercept: () => Promise<LlmResponseSimulator>;
completeAfterIntercept: () => Promise<LlmResponseSimulator>;
} {
const waitForInterceptPromise = Promise.race([
new Promise<LlmResponseSimulator>((outerResolve) => {
this.interceptors.push({
name,
when,
handle: (request, response, requestBody) => {
this.log.info(`LLM request intercepted by "${name}"`);
function write(chunk: string) {
return new Promise<void>((resolve) => response.write(chunk, () => resolve()));
}
@ -184,24 +212,28 @@ export class LlmProxy {
const simulator: LlmResponseSimulator = {
requestBody,
status: once(async (status: number) => {
status: once((status: number) => {
response.writeHead(status, {
'Elastic-Interceptor': name,
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
Connection: 'keep-alive',
});
}),
next: (msg) => {
simulator.status(200);
const chunk = createOpenAiChunk(msg);
return write(`data: ${JSON.stringify(chunk)}\n\n`);
return write(sseEvent(chunk));
},
rawWrite: (chunk: string) => {
simulator.status(200);
return write(chunk);
},
rawEnd: async () => {
await end();
},
complete: async () => {
this.log.debug(`Completed intercept for "${name}"`);
await write('data: [DONE]\n\n');
await end();
},
@ -216,29 +248,41 @@ export class LlmProxy {
});
}),
new Promise<LlmResponseSimulator>((_, reject) => {
setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 20000ms`)), 20000);
setTimeout(() => reject(new Error(`Interceptor "${name}" timed out after 30000ms`)), 30000);
}),
]);
if (responseChunks === undefined) {
return { waitForIntercept: () => waitForInterceptPromise } as any;
}
const parsedChunks = Array.isArray(responseChunks)
? responseChunks
: isString(responseChunks)
? responseChunks.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`))
: [responseChunks];
return {
waitForIntercept: () => waitForInterceptPromise,
completeAfterIntercept: async () => {
const simulator = await waitForInterceptPromise;
function getParsedChunks(): Array<string | ToolMessage> {
const llmMessage = isFunction(responseChunks)
? responseChunks(simulator.requestBody)
: responseChunks;
if (!llmMessage) {
return [];
}
if (Array.isArray(llmMessage)) {
return llmMessage;
}
if (isString(llmMessage)) {
return llmMessage.split(' ').map((token, i) => (i === 0 ? token : ` ${token}`));
}
return [llmMessage];
}
const parsedChunks = getParsedChunks();
for (const chunk of parsedChunks) {
await simulator.next(chunk);
}
await simulator.complete();
return simulator;
},
} as any;
@ -251,9 +295,7 @@ export async function createLlmProxy(log: ToolingLog) {
return new LlmProxy(port, log);
}
async function getRequestBody(
request: http.IncomingMessage
): Promise<OpenAI.Chat.ChatCompletionCreateParamsNonStreaming> {
async function getRequestBody(request: http.IncomingMessage): Promise<ChatCompletionStreamParams> {
return new Promise((resolve, reject) => {
let data = '';
@ -271,11 +313,6 @@ async function getRequestBody(
});
}
export function isFunctionTitleRequest(
requestBody: OpenAI.Chat.ChatCompletionCreateParamsNonStreaming
) {
return (
requestBody.tools?.find((fn) => fn.function.name === TITLE_CONVERSATION_FUNCTION_NAME) !==
undefined
);
// Frame an arbitrary chunk as a single Server-Sent Events "data" message
// (JSON payload followed by the mandatory blank-line terminator).
function sseEvent(chunk: unknown) {
  const payload = JSON.stringify(chunk);
  return 'data: ' + payload + '\n\n';
}

View file

@ -126,7 +126,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
await openContextualInsights();
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
await retry.tryForTime(5 * 1000, async () => {
const llmResponse = await testSubjects.getVisibleText(ui.pages.contextualInsights.text);

View file

@ -247,7 +247,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
await header.waitUntilLoadingHasFinished();
});
@ -256,6 +256,17 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
endpoint: 'POST /internal/observability_ai_assistant/conversations',
});
const functionResponse = await observabilityAIAssistantAPIClient.editor({
endpoint: 'GET /internal/observability_ai_assistant/functions',
params: {
query: {
scopes: ['observability'],
},
},
});
const primarySystemMessage = functionResponse.body.systemMessage;
expect(response.body.conversations.length).to.eql(2);
expect(response.body.conversations[0].conversation.title).to.be(expectedTitle);
@ -267,10 +278,13 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
const [firstUserMessage, contextRequest, contextResponse, assistantResponse] =
messages.map((msg) => msg.message);
const systemMessageContent =
'You are a helpful assistant for Elastic Observability. Your goal is to help the Elastic Observability users to quickly assess what is happening in their observed systems. You can help them visualise and analyze data, investigate their systems, perform root cause analysis or identify optimisation opportunities.\n\n It\'s very important to not assume what the user is meaning. Ask them for clarification if needed.\n\n If you are unsure about which function should be used and with what arguments, ask the user for clarification or confirmation.\n\n In KQL ("kqlFilter")) escaping happens with double quotes, not single quotes. Some characters that need escaping are: \':()\\ /". Always put a field value in double quotes. Best: service.name:"opbeans-go". Wrong: service.name:opbeans-go. This is very important!\n\n You can use Github-flavored Markdown in your responses. If a function returns an array, consider using a Markdown table to format the response.\n\n Note that ES|QL (the Elasticsearch Query Language which is a new piped language) is the preferred query language.\n\n If you want to call a function or tool, only call it a single time per message. 
Wait until the function has been executed and its results\n returned to you, before executing the same tool or another tool again if needed.\n\n DO NOT UNDER ANY CIRCUMSTANCES USE ES|QL syntax (`service.name == "foo"`) with "kqlFilter" (`service.name:"foo"`).\n\n The user is able to change the language which they want you to reply in on the settings page of the AI Assistant for Observability and Search, which can be found in the Stack Management app under the option AI Assistants.\n If the user asks how to change the language, reply in the same language the user asked in.\n\nYou MUST use the "query" function when the user wants to:\n - visualize data\n - run any arbitrary query\n - breakdown or filter ES|QL queries that are displayed on the current page\n - convert queries from another language to ES|QL\n - asks general questions about ES|QL\n\n DO NOT UNDER ANY CIRCUMSTANCES generate ES|QL queries or explain anything about the ES|QL query language yourself.\n DO NOT UNDER ANY CIRCUMSTANCES try to correct an ES|QL query yourself - always use the "query" function for this.\n\n If the user asks for a query, and one of the dataset info functions was called and returned no results, you should still call the query function to generate an example query.\n\n Even if the "query" function was used before that, follow it up with the "query" function. If a query fails, do not attempt to correct it yourself. Again you should call the "query" function,\n even if it has been called before.\n\n When the "visualize_query" function has been called, a visualization has been displayed to the user. DO NOT UNDER ANY CIRCUMSTANCES follow up a "visualize_query" function call with your own visualization attempt.\n If the "execute_query" function has been called, summarize these results for the user. 
The user does not see a visualization in this case.\n\nYou MUST use the "get_dataset_info" function before calling the "query" or the "changes" functions.\n\nIf a function requires an index, you MUST use the results from the dataset info functions.\n\nYou do not have a working memory. If the user expects you to remember the previous conversations, tell them they can set up the knowledge base.\n\nWhen asked questions about the Elastic stack or products, You should use the retrieve_elastic_doc function before answering,\n to retrieve documentation related to the question. Consider that the documentation returned by the function\n is always more up to date and accurate than any own internal knowledge you might have.';
expect(systemMessage).to.contain(
'You are a helpful assistant for Elastic Observability. Your goal is '
);
expect(systemMessage).to.eql(systemMessageContent);
expect(sortSystemMessage(systemMessage!)).to.eql(
sortSystemMessage(primarySystemMessage)
);
expect(firstUserMessage.content).to.eql('hello');
@ -305,7 +319,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
await testSubjects.setValue(ui.pages.conversations.chatInput, 'hello');
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
await header.waitUntilLoadingHasFinished();
});
@ -396,7 +410,7 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
await testSubjects.pressEnter(ui.pages.conversations.chatInput);
log.info('SQREN: Waiting for the message to be displayed');
await proxy.waitForAllInterceptorsSettled();
await proxy.waitForAllInterceptorsToHaveBeenCalled();
await header.waitUntilLoadingHasFinished();
});
@ -451,3 +465,11 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
});
});
}
// The order of instructions inside a system message can vary between runs,
// so we normalize for comparison: split into paragraphs, trim each, and sort.
function sortSystemMessage(message: string) {
  const paragraphs = message.split('\n\n');
  const trimmed = paragraphs.map((paragraph) => paragraph.trim());
  return trimmed.sort();
}