[Obs AI Assistant] Add test for get_dataset_info (#213231)

- Add API test for `get_dataset_info`
- Add apache synthtrace scenario
- Search local and remote clusters unless otherwise specified
This commit is contained in:
Søren Louv-Jansen 2025-03-07 13:53:10 +01:00 committed by GitHub
parent 2ead636ebd
commit 175e9066d0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 849 additions and 198 deletions

View file

@ -57,6 +57,7 @@ export type LogDocument = Fields &
'cloud.availability_zone'?: string;
'cloud.project.id'?: string;
'cloud.instance.id'?: string;
'client.ip'?: string;
'error.stack_trace'?: string;
'error.exception'?: unknown;
'error.log'?: unknown;
@ -68,6 +69,9 @@ export type LogDocument = Fields &
'event.duration': number;
'event.start': Date;
'event.end': Date;
'event.category'?: string;
'event.type'?: string;
'event.outcome'?: string;
labels?: Record<string, string>;
test_field: string | string[];
date: Date;
@ -76,8 +80,11 @@ export type LogDocument = Fields &
svc: string;
hostname: string;
[LONG_FIELD_NAME]: string;
'http.status_code'?: number;
'http.response.status_code'?: number;
'http.response.bytes'?: number;
'http.request.method'?: string;
'http.request.referrer'?: string;
'http.version'?: string;
'url.path'?: string;
'process.name'?: string;
'kubernetes.namespace'?: string;
@ -85,6 +92,7 @@ export type LogDocument = Fields &
'kubernetes.container.name'?: string;
'orchestrator.resource.name'?: string;
tags?: string | string[];
'user_agent.name'?: string;
}>;
class Log extends Serializable<LogDocument> {

View file

@ -0,0 +1,144 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
import { LogDocument, log } from '@kbn/apm-synthtrace-client';
import moment from 'moment';
import { random } from 'lodash';
import { Scenario } from '../cli/scenario';
import { withClient } from '../lib/utils/with_client';
import { parseLogsScenarioOpts } from './helpers/logs_scenario_opts_parser';
import { IndexTemplateName } from '../lib/logs/custom_logsdb_index_templates';
/**
 * Synthtrace scenario producing Apache-flavoured log documents in two datasets:
 *
 * - `apache.access`: access-log lines whose message is assembled from the
 *   fields returned by `constructApacheLogData()`.
 * - `apache.security`: "ATTACK SIMULATION" warnings carrying
 *   `event.category`/`event.type`/`event.outcome` fields.
 *
 * When the parsed scenario options set `isLogsDb`, the LogsDb index template is
 * created in `bootstrap` and deleted again in `teardown`.
 */
const scenario: Scenario<LogDocument> = async (runOptions) => {
  const { isLogsDb } = parseLogsScenarioOpts(runOptions.scenarioOpts);

  return {
    bootstrap: async ({ logsEsClient }) => {
      if (isLogsDb) {
        await logsEsClient.createIndexTemplate(IndexTemplateName.LogsDb);
      }
    },

    teardown: async ({ logsEsClient }) => {
      if (isLogsDb) {
        await logsEsClient.deleteIndexTemplate(IndexTemplateName.LogsDb);
      }
    },

    generate: ({ range, clients: { logsEsClient } }) => {
      const { logger } = runOptions;

      // Regular traffic: five access-log documents per generator invocation.
      const accessLogEvents = range
        .interval('1m')
        .rate(50)
        .generator((timestamp) =>
          Array.from({ length: 5 }, () => {
            const fields = constructApacheLogData();
            const requestTime = moment(timestamp).format('DD/MMM/YYYY:HH:mm:ss Z');
            const requestLine = `${fields['http.request.method']} ${fields['url.path']} HTTP/${fields['http.version']}`;

            return log
              .create({ isLogsDb })
              .message(
                `${fields['client.ip']} - - [${requestTime}] "${requestLine}" ${fields['http.response.status_code']} ${fields['http.response.bytes']}`
              )
              .dataset('apache.access')
              .defaults(fields)
              .timestamp(timestamp);
          })
        );

      // Simulated attacks: two warning documents per generator invocation.
      const securityLogEvents = range
        .interval('1m')
        .rate(2)
        .generator((timestamp) =>
          Array.from({ length: 2 }, () => {
            const fields = constructApacheLogData();
            const securityFields = {
              ...fields,
              'event.category': 'network',
              'event.type': 'access',
              'event.outcome': 'failure',
            };

            return log
              .create({ isLogsDb })
              .message(
                `ATTACK SIMULATION: ${fields['client.ip']} attempted access to restricted path ${fields['url.path']}`
              )
              .dataset('apache.security')
              .logLevel('warning')
              .defaults(securityFields)
              .timestamp(timestamp);
          })
        );

      return withClient(
        logsEsClient,
        logger.perf('generating_apache_logs', () => [accessLogEvents, securityLogEvents])
      );
    },
  };
};
export default scenario;
/**
 * Returns a uniformly random element of `items` (expected to be non-empty).
 */
function pickRandom<T>(items: readonly T[]): T {
  return items[Math.floor(Math.random() * items.length)];
}

/**
 * Builds the field set for a single synthetic Apache request.
 *
 * One of three canned request scenarios (method, path, status code, user agent,
 * referrer) is chosen at random and combined with a random hostname, cloud
 * region and client IP.
 *
 * `http.version` and `http.response.bytes` are populated as well: the access
 * log message built from this document interpolates both, and leaving them
 * unset rendered `HTTP/undefined` and a trailing `undefined` in every line.
 *
 * @returns a partial log document suitable for passing to `.defaults()`.
 */
function constructApacheLogData(): LogDocument {
  const APACHE_LOG_SCENARIOS = [
    {
      method: 'GET',
      path: '/index.html',
      responseCode: 200,
      userAgent: 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
      referrer: 'https://www.google.com',
    },
    {
      method: 'POST',
      path: '/login',
      responseCode: 401,
      userAgent: 'PostmanRuntime/7.29.0',
      referrer: '-',
    },
    {
      method: 'GET',
      path: '/admin/dashboard',
      responseCode: 403,
      userAgent: 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)',
      referrer: 'https://example.com/home',
    },
  ];
  const HOSTNAMES = ['www.example.com', 'blog.example.com', 'api.example.com'];
  const CLOUD_REGIONS = ['us-east-1', 'eu-west-2', 'ap-southeast-1'];
  const HTTP_VERSION = '1.1';
  // Exclusive upper bound for the synthetic response size, in bytes.
  const MAX_RESPONSE_BYTES = 10000;

  const { method, path, responseCode, userAgent, referrer } = pickRandom(APACHE_LOG_SCENARIOS);
  const clientIp = generateIpAddress();
  const hostname = pickRandom(HOSTNAMES);
  const cloudRegion = pickRandom(CLOUD_REGIONS);

  return {
    'http.request.method': method,
    'http.version': HTTP_VERSION,
    'url.path': path,
    'http.response.status_code': responseCode,
    'http.response.bytes': Math.floor(Math.random() * MAX_RESPONSE_BYTES),
    hostname,
    'cloud.region': cloudRegion,
    'cloud.availability_zone': `${cloudRegion}a`,
    'client.ip': clientIp,
    'user_agent.name': userAgent,
    'http.request.referrer': referrer,
  };
}
/**
 * Produces a dotted-quad IPv4 string whose four octets are each drawn with
 * `random(0, 255)`.
 */
function generateIpAddress() {
  const octets = Array.from({ length: 4 }, () => random(0, 255));
  return octets.join('.');
}

View file

@ -16,7 +16,7 @@ import { FunctionCallChatFunction } from '../../service/types';
const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields';
export const GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE = `You are a helpful assistant for Elastic Observability.
Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list.
The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id".
The list in the user message consists of JSON objects that map a human-readable field "name" to its unique "id".
You must not output any field names only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`;
export async function getRelevantFieldNames({
@ -114,10 +114,12 @@ export async function getRelevantFieldNames({
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:
content: `Below is a list of fields. Each entry is a JSON object that contains a "name" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:
${fieldsInChunk
.map((field) => JSON.stringify({ field, id: shortIdTable.take(field) }))
.map((fieldName) =>
JSON.stringify({ name: fieldName, id: shortIdTable.take(fieldName) })
)
.join('\n')}`,
},
},

View file

@ -5,8 +5,11 @@
* 2.0.
*/
import { IScopedClusterClient, Logger } from '@kbn/core/server';
import { Message } from '../../../common';
import { FunctionRegistrationParameters } from '..';
import { FunctionVisibility } from '../../../common/functions/types';
import { FunctionCallChatFunction, RespondFunctionResources } from '../../service/types';
import { getRelevantFieldNames } from './get_relevant_field_names';
export const GET_DATASET_INFO_FUNCTION_NAME = 'get_dataset_info';
@ -32,68 +35,102 @@ export function registerGetDatasetInfoFunction({
index: {
type: 'string',
description:
'index pattern the user is interested in or empty string to get information about all available indices',
'Please specify the index pattern(s) that are relevant to the user\'s query. You are allowed to specify multiple, comma-separated patterns like "index1,index2". If you provide an empty string, it will search across all indicies in all clusters. By default indicies that match the pattern in both local and remote clusters are searched. If you want to limit the search to a specific cluster you can prefix the index pattern with the cluster name. For example, "cluster1:my-index".',
},
},
required: ['index'],
} as const,
},
async ({ arguments: { index }, messages, chat }, signal) => {
const coreContext = await resources.context.core;
const esClient = coreContext.elasticsearch.client;
const savedObjectsClient = coreContext.savedObjects.client;
let indices: string[] = [];
try {
const body = await esClient.asCurrentUser.indices.resolveIndex({
name: index === '' ? ['*', '*:*'] : index.split(','),
expand_wildcards: 'open',
});
indices = [
...body.indices.map((i) => i.name),
...body.data_streams.map((d) => d.name),
...body.aliases.map((d) => d.name),
];
} catch (e) {
indices = [];
}
if (index === '') {
return {
content: {
indices,
fields: [],
},
};
}
if (indices.length === 0) {
return {
content: {
indices,
fields: [],
},
};
}
const relevantFieldNames = await getRelevantFieldNames({
index,
messages,
esClient: esClient.asCurrentUser,
dataViews: await resources.plugins.dataViews.start(),
savedObjectsClient,
signal,
chat,
});
return {
content: {
indices: [index],
fields: relevantFieldNames.fields,
stats: relevantFieldNames.stats,
},
};
async ({ arguments: { index: indexPattern }, messages, chat }, signal) => {
const content = await getDatasetInfo({ resources, indexPattern, signal, messages, chat });
return { content };
}
);
}
/**
 * Resolves `indexPattern` to concrete indices and asks `getRelevantFieldNames`
 * which of their fields matter for the conversation.
 *
 * @param resources    request-scoped resources (core context, plugins, logger)
 * @param indexPattern index pattern to resolve; an empty string lists the
 *                     resolved indices without selecting any fields
 * @param signal       abort signal forwarded to the field selection
 * @param messages     conversation messages forwarded to the field selection
 * @param chat         chat function forwarded to the field selection
 * @returns `{ indices, fields, stats }` on success, or `{ indices, fields: [] }`
 *          when the pattern is empty, nothing was resolved, or field selection
 *          threw (the error is logged and not rethrown)
 */
export async function getDatasetInfo({
  resources,
  indexPattern,
  signal,
  messages,
  chat,
}: {
  resources: RespondFunctionResources;
  indexPattern: string;
  signal: AbortSignal;
  messages: Message[];
  chat: FunctionCallChatFunction;
}) {
  const {
    elasticsearch: { client: esClient },
    savedObjects: { client: savedObjectsClient },
  } = await resources.context.core;

  const indices = await getIndicesFromIndexPattern(indexPattern, esClient, resources.logger);

  // Skip field selection when there is nothing to inspect or no pattern was given.
  const skipFieldSelection = indexPattern === '' || indices.length === 0;
  if (skipFieldSelection) {
    return { indices, fields: [] };
  }

  try {
    const currentUserClient = esClient.asCurrentUser;
    const dataViews = await resources.plugins.dataViews.start();

    const relevant = await getRelevantFieldNames({
      index: indices,
      messages,
      esClient: currentUserClient,
      dataViews,
      savedObjectsClient,
      signal,
      chat,
    });

    return { indices, fields: relevant.fields, stats: relevant.stats };
  } catch (e) {
    resources.logger.error(`Error getting relevant field names: ${e.message}`);
    return { indices, fields: [] };
  }
}
/**
 * Resolves an index pattern to the names of matching indices, data streams and
 * aliases using the resolve-index API (open, non-hidden targets only).
 *
 * - An empty `indexPattern` searches everything locally and on remote clusters
 *   (`*` and `*:*`).
 * - Otherwise the pattern is split on commas. Each segment is trimmed and empty
 *   segments are dropped, so `"a, b"` and `"a,,b"` behave like `"a,b"`. Without
 *   this a stray comma would expand to `**` and match every index.
 * - A segment of the form `cluster:name` is searched as `cluster:*name*`; any
 *   other segment is searched both locally (`*name*`) and on all remote
 *   clusters (`*:*name*`).
 * - If an index or alias is named exactly like the (trimmed) pattern, only that
 *   name is returned.
 *
 * @param indexPattern comma-separated index pattern(s), or `''` for everything
 * @param esClient     scoped cluster client; the request runs as the current user
 * @param logger       receives an error entry if the resolve call throws
 * @returns the matching names, or `[]` when nothing usable was supplied or the
 *          resolve call failed
 */
async function getIndicesFromIndexPattern(
  indexPattern: string,
  esClient: IScopedClusterClient,
  logger: Logger
) {
  let name: string[];

  if (indexPattern === '') {
    name = ['*', '*:*'];
  } else {
    const patterns = indexPattern
      .split(',')
      .map((pattern) => pattern.trim())
      .filter((pattern) => pattern !== '');

    // Only separators/whitespace were supplied: nothing to resolve.
    if (patterns.length === 0) {
      return [];
    }

    name = patterns.flatMap((pattern) => {
      // search specific cluster
      if (pattern.includes(':')) {
        const [cluster, p] = pattern.split(':');
        return `${cluster}:*${p}*`;
      }

      // search across local and remote clusters
      return [`*${pattern}*`, `*:*${pattern}*`];
    });
  }

  try {
    const body = await esClient.asCurrentUser.indices.resolveIndex({
      name,
      expand_wildcards: 'open', // exclude hidden and closed indices
    });

    // if there is an exact match, only return that
    const exactName = indexPattern.trim();
    const hasExactMatch =
      body.indices.some((i) => i.name === exactName) ||
      body.aliases.some((i) => i.name === exactName);

    if (hasExactMatch) {
      return [exactName];
    }

    // otherwise return all matching indices, data streams, and aliases
    return [
      ...body.indices.map((i) => i.name),
      ...body.data_streams.map((d) => d.name),
      ...body.aliases.map((d) => d.name),
    ];
  } catch (e) {
    const reason = e instanceof Error ? e.message : String(e);
    logger.error(`Error resolving index pattern: ${reason}`);
    return [];
  }
}

View file

@ -6,6 +6,7 @@
*/
import { notImplemented } from '@hapi/boom';
import { nonEmptyStringRt, toBooleanRt } from '@kbn/io-ts-utils';
import { context as otelContext } from '@opentelemetry/api';
import * as t from 'io-ts';
import { v4 } from 'uuid';
import { FunctionDefinition } from '../../../common/functions/types';
@ -14,6 +15,8 @@ import type { RecalledEntry } from '../../service/knowledge_base_service';
import { getSystemMessageFromInstructions } from '../../service/util/get_system_message_from_instructions';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { assistantScopeType } from '../runtime_types';
import { getDatasetInfo } from '../../functions/get_dataset_info';
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
endpoint: 'GET /internal/observability_ai_assistant/functions',
@ -78,6 +81,47 @@ const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
},
});
/**
 * Internal route that runs `getDatasetInfo` directly, outside of a chat
 * conversation.
 *
 * Query params: `index` (the index pattern) and `connectorId` (the connector
 * used for the chat calls made during field selection). Requires the
 * `ai_assistant` privilege. The work is aborted when the incoming request's
 * `aborted$` observable emits. No conversation messages are passed along.
 */
const functionDatasetInfoRoute = createObservabilityAIAssistantServerRoute({
  endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
  params: t.type({
    query: t.type({ index: t.string, connectorId: t.string }),
  }),
  security: {
    authz: {
      requiredPrivileges: ['ai_assistant'],
    },
  },
  handler: async (resources) => {
    const { request, service, params } = resources;
    const client = await service.getClient({ request });
    const { index: indexPattern, connectorId } = params.query;

    // Propagate client disconnects to the downstream calls.
    const abortController = new AbortController();
    request.events.aborted$.subscribe(() => {
      abortController.abort();
    });

    return getDatasetInfo({
      resources,
      indexPattern,
      signal: abortController.signal,
      messages: [],
      chat: (operationName, chatParams) =>
        client.chat(operationName, {
          ...chatParams,
          stream: true,
          tracer: new LangTracer(otelContext.active()),
          connectorId,
        }),
    });
  },
});
const functionRecallRoute = createObservabilityAIAssistantServerRoute({
endpoint: 'POST /internal/observability_ai_assistant/functions/recall',
params: t.type({
@ -176,4 +220,5 @@ export const functionRoutes = {
...getFunctionsRoute,
...functionRecallRoute,
...functionSummariseRoute,
...functionDatasetInfoRoute,
};

View file

@ -17,9 +17,10 @@ import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper';
import {
LlmProxy,
RelevantField,
createLlmProxy,
} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { getMessageAddedEvents } from './helpers';
import { chatComplete, getSystemMessage, systemMessageSorted } from './helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper';
@ -31,8 +32,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const alertingApi = getService('alertingApi');
const samlAuth = getService('samlAuth');
describe('function: get_alerts_dataset_info', function () {
// Fails on MKI: https://github.com/elastic/kibana/issues/205581
describe('get_alerts_dataset_info', function () {
this.tags(['failsOnMKI']);
let llmProxy: LlmProxy;
let connectorId: string;
@ -40,8 +40,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
let apmSynthtraceEsClient: ApmSynthtraceEsClient;
let roleAuthc: RoleCredentials;
let createdRuleId: string;
let expectedRelevantFieldNames: string[];
let primarySystemMessage: string;
let getRelevantFields: () => Promise<RelevantField[]>;
before(async () => {
({ apmSynthtraceEsClient } = await createSyntheticApmData(getService));
@ -58,26 +57,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
when: () => true,
});
void llmProxy.interceptWithFunctionRequest({
name: 'select_relevant_fields',
// @ts-expect-error
when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields',
arguments: (requestBody) => {
const userMessage = last(requestBody.messages);
const topFields = (userMessage?.content as string)
.slice(204) // remove the prefix message and only get the JSON
.trim()
.split('\n')
.map((line) => JSON.parse(line))
.slice(0, 5);
expectedRelevantFieldNames = topFields.map(({ field }) => field);
const fieldIds = topFields.map(({ id }) => id);
return JSON.stringify({ fieldIds });
},
});
({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice());
void llmProxy.interceptWithFunctionRequest({
name: 'alerts',
@ -89,44 +69,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
`You have active alerts for the past 10 days. Back to work!`
);
const { status, body } = await observabilityAIAssistantAPIClient.editor({
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
params: {
body: {
messages: [
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
content: USER_MESSAGE,
},
},
],
connectorId,
persist: false,
screenContexts: [],
scopes: ['observability' as const],
},
},
});
expect(status).to.be(200);
({ messageAddedEvents } = await chatComplete({
userPrompt: USER_MESSAGE,
connectorId,
observabilityAIAssistantAPIClient,
}));
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
messageAddedEvents = getMessageAddedEvents(body);
const {
body: { systemMessage },
} = await observabilityAIAssistantAPIClient.editor({
endpoint: 'GET /internal/observability_ai_assistant/functions',
params: {
query: {
scopes: ['observability'],
},
},
});
primarySystemMessage = systemMessage;
});
after(async () => {
@ -146,7 +95,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await samlAuth.invalidateM2mApiKeyWithRoleScope(roleAuthc);
});
describe('LLM requests', () => {
describe('POST /internal/observability_ai_assistant/chat/complete', () => {
let firstRequestBody: ChatCompletionStreamParams;
let secondRequestBody: ChatCompletionStreamParams;
let thirdRequestBody: ChatCompletionStreamParams;
@ -163,6 +112,51 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(llmProxy.interceptedRequests.length).to.be(4);
});
it('emits 7 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(7);
});
it('emits messageAdded events in the correct order', async () => {
const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => {
const { role, name, function_call: functionCall } = message.message;
if (functionCall) {
return { function_call: functionCall, role };
}
return { name, role };
});
expect(formattedMessageAddedEvents).to.eql([
{
role: 'assistant',
function_call: { name: 'context', trigger: 'assistant' },
},
{ name: 'context', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'get_alerts_dataset_info',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'get_alerts_dataset_info', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'alerts',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'alerts', role: 'user' },
{
role: 'assistant',
function_call: { name: '', arguments: '', trigger: 'assistant' },
},
]);
});
describe('every request to the LLM', () => {
it('contains a system message', () => {
const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every(
@ -228,9 +222,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
describe('The system message', () => {
it('has the primary system message', () => {
expect(sortSystemMessage(firstRequestBody.messages[0].content as string)).to.eql(
sortSystemMessage(primarySystemMessage)
it('has the primary system message', async () => {
const primarySystemMessage = await getSystemMessage(getService);
expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql(
systemMessageSorted(primarySystemMessage)
);
});
@ -254,14 +249,11 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
it('contains a system generated user message with a list of field candidates', () => {
const hasList = secondRequestBody.messages.some(
(message) =>
message.role === 'user' &&
(message.content as string).includes('Below is a list of fields.') &&
(message.content as string).includes('@timestamp')
);
const lastMessage = last(secondRequestBody.messages);
expect(hasList).to.be(true);
expect(lastMessage?.role).to.be('user');
expect(lastMessage?.content).to.contain('Below is a list of fields');
expect(lastMessage?.content).to.contain('@timestamp');
});
it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => {
@ -294,7 +286,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(hasFunctionRequest).to.be(true);
});
it('contains the `get_alerts_dataset_info` response', () => {
it('contains the `get_alerts_dataset_info` response', async () => {
const functionResponse = last(thirdRequestBody.messages);
const parsedContent = JSON.parse(functionResponse?.content as string) as {
fields: string[];
@ -303,7 +295,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
const fieldNamesWithType = parsedContent.fields;
const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]);
expect(fieldNamesWithoutType).to.eql(expectedRelevantFieldNames);
const relevantFields = await getRelevantFields();
expect(fieldNamesWithoutType).to.eql(relevantFields.map(({ name }) => name));
expect(fieldNamesWithType).to.eql([
'@timestamp:date',
'_id:_id',
@ -314,13 +307,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
it('emits a messageAdded event with the `get_alerts_dataset_info` function response', async () => {
const messageWithDatasetInfo = messageAddedEvents.find(
const eventWithDatasetInfo = messageAddedEvents.find(
({ message }) =>
message.message.role === MessageRole.User &&
message.message.name === 'get_alerts_dataset_info'
);
const parsedContent = JSON.parse(messageWithDatasetInfo?.message.message.content!) as {
const parsedContent = JSON.parse(eventWithDatasetInfo?.message.message.content!) as {
fields: string[];
};
@ -361,12 +354,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
it('emits a messageAdded event with the `alert` function response', async () => {
const messageWithAlerts = messageAddedEvents.find(
const event = messageAddedEvents.find(
({ message }) =>
message.message.role === MessageRole.User && message.message.name === 'alerts'
);
const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as {
const parsedContent = JSON.parse(event?.message.message.content!) as {
total: number;
alerts: any[];
};
@ -375,53 +368,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
});
});
});
describe('messageAdded events', () => {
it('emits 7 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(7);
});
it('emits messageAdded events in the correct order', async () => {
const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => {
const { role, name, function_call: functionCall } = message.message;
if (functionCall) {
return { function_call: functionCall, role };
}
return { name, role };
});
expect(formattedMessageAddedEvents).to.eql([
{
role: 'assistant',
function_call: { name: 'context', trigger: 'assistant' },
},
{ name: 'context', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'get_alerts_dataset_info',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'get_alerts_dataset_info', role: 'user' },
{
role: 'assistant',
function_call: {
name: 'alerts',
arguments: '{"start":"now-10d","end":"now"}',
trigger: 'assistant',
},
},
{ name: 'alerts', role: 'user' },
{
role: 'assistant',
function_call: { name: '', arguments: '', trigger: 'assistant' },
},
]);
});
});
});
}
@ -490,11 +436,3 @@ async function createSyntheticApmData(
return { apmSynthtraceEsClient };
}
// order of instructions can vary, so we sort to compare them
function sortSystemMessage(message: string) {
return message
.split('\n\n')
.map((line) => line.trim())
.sort();
}

View file

@ -0,0 +1,341 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
import expect from '@kbn/expect';
import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace';
import { last } from 'lodash';
import { GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names';
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
import {
LlmProxy,
RelevantField,
createLlmProxy,
} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { chatComplete, getSystemMessage, systemMessageSorted } from './helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
import { createSimpleSyntheticLogs } from '../../synthtrace_scenarios/simple_logs';
export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
const log = getService('log');
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
const synthtrace = getService('synthtrace');
describe('get_dataset_info', function () {
this.tags(['failsOnMKI']);
let llmProxy: LlmProxy;
let connectorId: string;
before(async () => {
llmProxy = await createLlmProxy(log);
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
port: llmProxy.getPort(),
});
});
after(async () => {
llmProxy.close();
await observabilityAIAssistantAPIClient.deleteActionConnector({
actionId: connectorId,
});
});
// Calling `get_dataset_info` via the chat/complete endpoint
describe('POST /internal/observability_ai_assistant/chat/complete', function () {
let messageAddedEvents: MessageAddEvent[];
let logsSynthtraceEsClient: LogsSynthtraceEsClient;
let getRelevantFields: () => Promise<RelevantField[]>;
let firstRequestBody: ChatCompletionStreamParams;
let secondRequestBody: ChatCompletionStreamParams;
let thirdRequestBody: ChatCompletionStreamParams;
const USER_MESSAGE = 'Do I have any Apache logs?';
before(async () => {
logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient();
await createSimpleSyntheticLogs({ logsSynthtraceEsClient });
void llmProxy.interceptWithFunctionRequest({
name: 'get_dataset_info',
arguments: () => JSON.stringify({ index: 'logs*' }),
when: () => true,
});
({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice());
void llmProxy.interceptConversation(`Yes, you do have logs. Congratulations! 🎈️🎈️🎈️`);
({ messageAddedEvents } = await chatComplete({
userPrompt: USER_MESSAGE,
connectorId,
observabilityAIAssistantAPIClient,
}));
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
thirdRequestBody = llmProxy.interceptedRequests[2].requestBody;
});
after(async () => {
await logsSynthtraceEsClient.clean();
});
it('makes 3 requests to the LLM', () => {
expect(llmProxy.interceptedRequests.length).to.be(3);
});
it('emits 5 messageAdded events', () => {
expect(messageAddedEvents.length).to.be(5);
});
describe('every request to the LLM', () => {
it('contains a system message', () => {
const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every(
({ requestBody }) => {
const firstMessage = requestBody.messages[0];
return (
firstMessage.role === 'system' &&
(firstMessage.content as string).includes('You are a helpful assistant')
);
}
);
expect(everyRequestHasSystemMessage).to.be(true);
});
it('contains the original user message', () => {
const everyRequestHasUserMessage = llmProxy.interceptedRequests.every(({ requestBody }) =>
requestBody.messages.some(
(message) => message.role === 'user' && (message.content as string) === USER_MESSAGE
)
);
expect(everyRequestHasUserMessage).to.be(true);
});
it('contains the context function request and context function response', () => {
const everyRequestHasContextFunction = llmProxy.interceptedRequests.every(
({ requestBody }) => {
const hasContextFunctionRequest = requestBody.messages.some(
(message) =>
message.role === 'assistant' &&
message.tool_calls?.[0]?.function?.name === 'context'
);
const hasContextFunctionResponse = requestBody.messages.some(
(message) =>
message.role === 'tool' &&
(message.content as string).includes('screen_description') &&
(message.content as string).includes('learnings')
);
return hasContextFunctionRequest && hasContextFunctionResponse;
}
);
expect(everyRequestHasContextFunction).to.be(true);
});
});
describe('The first request', () => {
it('contains the correct number of messages', () => {
expect(firstRequestBody.messages.length).to.be(4);
});
it('contains the `get_dataset_info` tool', () => {
const hasTool = firstRequestBody.tools?.some(
(tool) => tool.function.name === 'get_dataset_info'
);
expect(hasTool).to.be(true);
});
it('leaves the function calling decision to the LLM via tool_choice=auto', () => {
expect(firstRequestBody.tool_choice).to.be('auto');
});
describe('The system message', () => {
it('has the primary system message', async () => {
const primarySystemMessage = await getSystemMessage(getService);
expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql(
systemMessageSorted(primarySystemMessage)
);
});
it('has a different system message from request 2', () => {
expect(firstRequestBody.messages[0]).not.to.eql(secondRequestBody.messages[0]);
});
it('has the same system message as request 3', () => {
expect(firstRequestBody.messages[0]).to.eql(thirdRequestBody.messages[0]);
});
});
});
describe('The second request', () => {
it('contains the correct number of messages', () => {
expect(secondRequestBody.messages.length).to.be(5);
});
it('contains a system generated user message with a list of field candidates', () => {
const lastMessage = last(secondRequestBody.messages);
expect(lastMessage?.role).to.be('user');
expect(lastMessage?.content).to.contain('Below is a list of fields');
expect(lastMessage?.content).to.contain('@timestamp');
});
it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => {
const hasToolChoice =
// @ts-expect-error
secondRequestBody.tool_choice?.function?.name === 'select_relevant_fields';
expect(hasToolChoice).to.be(true);
});
it('has a custom, function-specific system message', () => {
expect(secondRequestBody.messages[0].content).to.be(
GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE
);
});
});
// Assertions on the third request captured from the LLM proxy: the one that
// carries the `get_dataset_info` tool call and its response back to the model.
describe('The third request', () => {
  it('contains the correct number of messages', () => {
    const { messages } = thirdRequestBody;
    expect(messages.length).to.be(6);
  });

  it('contains the `get_dataset_info` request', () => {
    const datasetInfoRequest = thirdRequestBody.messages.find(
      (message) =>
        message.role === 'assistant' &&
        message.tool_calls?.[0]?.function?.name === 'get_dataset_info'
    );
    expect(datasetInfoRequest !== undefined).to.be(true);
  });

  it('contains the `get_dataset_info` response', () => {
    // The function response is the last message of the request.
    const rawContent = last(thirdRequestBody.messages)?.content as string;
    const datasetInfo = JSON.parse(rawContent);

    expect(Object.keys(datasetInfo)).to.eql(['indices', 'fields', 'stats']);
    expect(datasetInfo.indices).to.contain('logs-web.access-default');
  });

  it('emits a messageAdded event with the `get_dataset_info` function response', async () => {
    const datasetInfoEvent = messageAddedEvents.find((event) => {
      const { role, name } = event.message.message;
      return role === MessageRole.User && name === 'get_dataset_info';
    });

    const { indices, fields } = JSON.parse(datasetInfoEvent?.message.message.content!) as {
      indices: string[];
      fields: string[];
    };

    // Fields are serialized as `name:type`; only the name part is compared.
    const returnedFieldNames = fields.map((field) => field.split(':')[0]);

    const relevantFields = await getRelevantFields();
    const expectedFieldNames = relevantFields.map(({ name }) => name);

    expect(returnedFieldNames).to.eql(expectedFieldNames);
    expect(indices).to.contain('logs-web.access-default');
  });
});
});
// Calling `get_dataset_info` directly
describe('GET /internal/observability_ai_assistant/functions/get_dataset_info', () => {
  let logsSynthtraceEsClient: LogsSynthtraceEsClient;

  // Issues the request through the `editor` API client for the given index pattern(s).
  // `connectorId` is read at call time, so it picks up the value assigned in the outer setup.
  const getDatasetInfo = (index: string) =>
    observabilityAIAssistantAPIClient.editor({
      endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
      params: {
        query: {
          index,
          connectorId,
        },
      },
    });

  before(async () => {
    logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient();

    // Index two distinct datasets so the tests can tell index patterns apart.
    await Promise.all([
      createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'zookeeper.access' }),
      createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'apache.access' }),
    ]);
  });

  after(async () => {
    await logsSynthtraceEsClient.clean();
  });

  it('returns Zookeeper logs but not the Apache logs', async () => {
    llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });

    const { body } = await getDatasetInfo('zookeeper');

    // Make sure the interceptor registered above was consumed by this test,
    // so it cannot leak into (and be matched by) a later test.
    await llmProxy.waitForAllInterceptorsToHaveBeenCalled();

    expect(body.indices).to.eql(['logs-zookeeper.access-default']);
    expect(body.fields.length).to.be.greaterThan(0);
  });

  it('returns both Zookeeper and Apache logs', async () => {
    llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });

    const { body } = await getDatasetInfo('logs');

    await llmProxy.waitForAllInterceptorsToHaveBeenCalled();

    expect(body.indices).to.contain('logs-apache.access-default');
    expect(body.indices).to.contain('logs-zookeeper.access-default');
    expect(body.fields.length).to.be.greaterThan(0);
  });

  it('accepts a comma-separated list of patterns', async () => {
    llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });

    const { body } = await getDatasetInfo('zookeeper,apache');

    await llmProxy.waitForAllInterceptorsToHaveBeenCalled();

    expect(body.indices).to.eql([
      'logs-apache.access-default',
      'logs-zookeeper.access-default',
    ]);
  });

  it('handles no matching indices gracefully', async () => {
    // No interceptor is registered here: the test only asserts empty results.
    const { body } = await getDatasetInfo('foobarbaz');

    expect(body.indices).to.eql([]);
    expect(body.fields).to.eql([]);
  });
});
});
}

View file

@ -13,6 +13,7 @@ import {
} from '@kbn/observability-ai-assistant-plugin/common';
import { Readable } from 'stream';
import type { AssistantScope } from '@kbn/ai-assistant-common';
import { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
import type { ObservabilityAIAssistantApiClient } from '../../../../../services/observability_ai_assistant_api';
function decodeEvents(body: Readable | string) {
@ -73,3 +74,64 @@ export async function invokeChatCompleteWithFunctionRequest({
return body;
}
/**
 * Posts a single user message to the `chat/complete` endpoint without
 * persisting the conversation, and asserts that the response status is 200.
 *
 * @param userPrompt Text sent as the content of the user message.
 * @param connectorId Connector id forwarded in the request body.
 * @param observabilityAIAssistantAPIClient API client used to issue the request.
 * @returns The response `status` and `body`, plus the result of running
 *          `getMessageAddedEvents` on the body.
 */
export async function chatComplete({
  userPrompt,
  connectorId,
  observabilityAIAssistantAPIClient,
}: {
  userPrompt: string;
  connectorId: string;
  observabilityAIAssistantAPIClient: ObservabilityAIAssistantApiClient;
}) {
  const timestamp = new Date().toISOString();

  const response = await observabilityAIAssistantAPIClient.editor({
    endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
    params: {
      body: {
        messages: [
          {
            '@timestamp': timestamp,
            message: { role: MessageRole.User, content: userPrompt },
          },
        ],
        connectorId,
        persist: false,
        screenContexts: [],
        scopes: ['observability' as const],
      },
    },
  });

  const { status, body } = response;
  expect(status).to.be(200);

  return { messageAddedEvents: getMessageAddedEvents(body), body, status };
}
/**
 * Splits a system message into its blank-line-separated instructions, trims
 * each one and returns them sorted. The order of instructions can vary between
 * runs, so callers compare the sorted form instead of the raw string.
 */
export function systemMessageSorted(message: string) {
  const instructions: string[] = [];

  for (const paragraph of message.split('\n\n')) {
    instructions.push(paragraph.trim());
  }

  // `instructions` is a fresh local array, so sorting in place does not affect the caller.
  instructions.sort();
  return instructions;
}
/**
 * Requests the `functions` endpoint for the `observability` scope and returns
 * the `systemMessage` property of the response body.
 *
 * @param getService FTR service locator used to obtain the `observabilityAIAssistantApi` client.
 */
export async function getSystemMessage(
  getService: DeploymentAgnosticFtrProviderContext['getService']
) {
  const observabilityAIAssistantApi = getService('observabilityAIAssistantApi');

  const response = await observabilityAIAssistantApi.editor({
    endpoint: 'GET /internal/observability_ai_assistant/functions',
    params: { query: { scopes: ['observability'] } },
  });

  return response.body.systemMessage;
}

View file

@ -17,6 +17,7 @@ export default function aiAssistantApiIntegrationTests({
loadTestFile(require.resolve('./complete/complete.spec.ts'));
loadTestFile(require.resolve('./complete/functions/alerts.spec.ts'));
loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts'));
loadTestFile(require.resolve('./complete/functions/get_dataset_info.spec.ts'));
loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts'));
loadTestFile(require.resolve('./complete/functions/summarize.spec.ts'));
loadTestFile(require.resolve('./public_complete/public_complete.spec.ts'));

View file

@ -0,0 +1,34 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { timerange, log } from '@kbn/apm-synthtrace-client';
import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace';
/**
 * Generates synthetic log documents over the `now-15m`..`now` time range
 * (interval `1m`, rate `1`) and indexes them with the given synthtrace client.
 *
 * @param logsSynthtraceEsClient Client used to index the generated logs.
 * @param message Log message; falls back to `'simple log message'`.
 * @param dataset Dataset name; falls back to `'web.access'`.
 */
export async function createSimpleSyntheticLogs({
  logsSynthtraceEsClient,
  message,
  dataset,
}: {
  logsSynthtraceEsClient: LogsSynthtraceEsClient;
  message?: string;
  dataset?: string;
}) {
  // Resolve the fallbacks once instead of on every generated event.
  const logMessage = message ?? 'simple log message';
  const logDataset = dataset ?? 'web.access';

  const logEvents = timerange('now-15m', 'now')
    .interval('1m')
    .rate(1)
    .generator((timestamp) =>
      log.create().message(logMessage).dataset(logDataset).timestamp(timestamp)
    );

  await logsSynthtraceEsClient.index([logEvents]);
}

View file

@ -13,7 +13,7 @@ interface GetApmSynthtraceEsClientParams {
packageVersion: string;
}
export async function getApmSynthtraceEsClient({
export function getApmSynthtraceEsClient({
client,
packageVersion,
}: GetApmSynthtraceEsClientParams) {

View file

@ -8,7 +8,7 @@
import { Client } from '@elastic/elasticsearch';
import { InfraSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace';
export async function getInfraSynthtraceEsClient(client: Client) {
export function getInfraSynthtraceEsClient(client: Client) {
return new InfraSynthtraceEsClient({
client,
logger: createLogger(LogLevel.info),

View file

@ -8,7 +8,7 @@
import { Client } from '@elastic/elasticsearch';
import { LogsSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace';
export async function getLogsSynthtraceEsClient(client: Client) {
export function getLogsSynthtraceEsClient(client: Client) {
return new LogsSynthtraceEsClient({
client,
logger: createLogger(LogLevel.info),

View file

@ -9,7 +9,7 @@ import { ToolingLog } from '@kbn/tooling-log';
import getPort from 'get-port';
import { v4 as uuidv4 } from 'uuid';
import http, { type Server } from 'http';
import { isString, once, pull, isFunction } from 'lodash';
import { isString, once, pull, isFunction, last } from 'lodash';
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
import pRetry from 'p-retry';
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
@ -36,6 +36,12 @@ export interface ToolMessage {
content?: string;
tool_calls?: ChatCompletionChunkToolCall[];
}
/**
 * A field candidate as parsed from one JSON line of the field-selection prompt
 * intercepted by `interceptSelectRelevantFieldsToolChoice`.
 */
export interface RelevantField {
  // Identifier returned to the caller in the simulated `fieldIds` tool-call arguments.
  id: string;
  // Name of the field.
  name: string;
}
export interface LlmResponseSimulator {
requestBody: ChatCompletionStreamParams;
status: (code: number) => void;
@ -180,6 +186,39 @@ export class LlmProxy {
}).completeAfterIntercept();
}
/**
 * Registers an interceptor for requests whose `tool_choice` targets the
 * `select_relevant_fields` function. The simulated tool call answers with the
 * ids of the field candidates found in the last request message, restricted to
 * the `[from, to)` slice of the candidate lines.
 *
 * @param from Start index (inclusive) into the candidate lines; defaults to 0.
 * @param to End index (exclusive) into the candidate lines; defaults to 5.
 * @returns The `simulator` returned by `interceptWithFunctionRequest`, and
 *          `getRelevantFields`, which awaits the simulator and then resolves
 *          with the fields that were selected.
 */
interceptSelectRelevantFieldsToolChoice({
  from = 0,
  to = 5,
}: { from?: number; to?: number } = {}) {
  // Filled in by the `arguments` callback once a matching request is intercepted.
  let selectedFields: RelevantField[] = [];

  const simulator = this.interceptWithFunctionRequest({
    name: 'select_relevant_fields',
    // @ts-expect-error
    when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields',
    arguments: (requestBody) => {
      const prompt = last(requestBody.messages)?.content as string;

      // Drop the first blank-line-separated section; the remainder is treated
      // as one JSON document per line.
      const [, ...candidateSections] = prompt.split('\n\n');
      const candidateLines = candidateSections.join('').trim().split('\n');

      selectedFields = candidateLines
        .slice(from, to)
        .map((line) => JSON.parse(line) as RelevantField);

      const fieldIds = selectedFields.map((field) => field.id);
      return JSON.stringify({ fieldIds });
    },
  });

  const getRelevantFields = async () => {
    await simulator;
    return selectedFields;
  };

  return { simulator, getRelevantFields };
}
interceptTitle(title: string) {
return this.interceptWithFunctionRequest({
name: TITLE_CONVERSATION_FUNCTION_NAME,