mirror of
https://github.com/elastic/kibana.git
synced 2025-04-23 09:19:04 -04:00
# Backport This will backport the following commits from `main` to `8.x`: - [[Obs AI Assistant] Add test for `get_dataset_info` (#213231)](https://github.com/elastic/kibana/pull/213231) <!--- Backport version: 9.6.6 --> ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sorenlouv/backport) <!--BACKPORT [{"author":{"name":"Søren Louv-Jansen","email":"soren.louv@elastic.co"},"sourceCommit":{"committedDate":"2025-03-07T12:53:10Z","message":"[Obs AI Assistant] Add test for `get_dataset_info` (#213231)\n\n- Add API test for `get_dataset_info`\n- Add apache synthtrace scenario\n- Search local and remote clusters unless otherwise specified","sha":"175e9066d00a76985bf956b2cf693fb0319b9940","branchLabelMapping":{"^v9.1.0$":"main","^v8.19.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","Team:Obs AI Assistant","ci:project-deploy-observability","Team:obs-ux-infra_services","backport:version","v9.1.0","v8.19.0"],"title":"[Obs AI Assistant] Add test for `get_dataset_info`","number":213231,"url":"https://github.com/elastic/kibana/pull/213231","mergeCommit":{"message":"[Obs AI Assistant] Add test for `get_dataset_info` (#213231)\n\n- Add API test for `get_dataset_info`\n- Add apache synthtrace scenario\n- Search local and remote clusters unless otherwise specified","sha":"175e9066d00a76985bf956b2cf693fb0319b9940"}},"sourceBranch":"main","suggestedTargetBranches":["9.0","8.x"],"targetPullRequestStates":[{"branch":"9.0","label":"v9.0.0","branchLabelMappingKey":"^v(\\d+).(\\d+).\\d+$","isSourceBranch":false,"state":"NOT_CREATED"},{"branch":"main","label":"v9.1.0","branchLabelMappingKey":"^v9.1.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/213231","number":213231,"mergeCommit":{"message":"[Obs AI Assistant] Add test for `get_dataset_info` (#213231)\n\n- Add API test for `get_dataset_info`\n- Add apache synthtrace scenario\n- Search local and remote 
clusters unless otherwise specified","sha":"175e9066d00a76985bf956b2cf693fb0319b9940"}},{"branch":"8.x","label":"v8.19.0","branchLabelMappingKey":"^v8.19.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}] BACKPORT--> Co-authored-by: Søren Louv-Jansen <soren.louv@elastic.co>
This commit is contained in:
parent
3892ad8c0e
commit
b414e1f3b0
14 changed files with 849 additions and 198 deletions
|
@ -57,6 +57,7 @@ export type LogDocument = Fields &
|
|||
'cloud.availability_zone'?: string;
|
||||
'cloud.project.id'?: string;
|
||||
'cloud.instance.id'?: string;
|
||||
'client.ip'?: string;
|
||||
'error.stack_trace'?: string;
|
||||
'error.exception'?: unknown;
|
||||
'error.log'?: unknown;
|
||||
|
@ -68,6 +69,9 @@ export type LogDocument = Fields &
|
|||
'event.duration': number;
|
||||
'event.start': Date;
|
||||
'event.end': Date;
|
||||
'event.category'?: string;
|
||||
'event.type'?: string;
|
||||
'event.outcome'?: string;
|
||||
labels?: Record<string, string>;
|
||||
test_field: string | string[];
|
||||
date: Date;
|
||||
|
@ -76,8 +80,11 @@ export type LogDocument = Fields &
|
|||
svc: string;
|
||||
hostname: string;
|
||||
[LONG_FIELD_NAME]: string;
|
||||
'http.status_code'?: number;
|
||||
'http.response.status_code'?: number;
|
||||
'http.response.bytes'?: number;
|
||||
'http.request.method'?: string;
|
||||
'http.request.referrer'?: string;
|
||||
'http.version'?: string;
|
||||
'url.path'?: string;
|
||||
'process.name'?: string;
|
||||
'kubernetes.namespace'?: string;
|
||||
|
@ -85,6 +92,7 @@ export type LogDocument = Fields &
|
|||
'kubernetes.container.name'?: string;
|
||||
'orchestrator.resource.name'?: string;
|
||||
tags?: string | string[];
|
||||
'user_agent.name'?: string;
|
||||
}>;
|
||||
|
||||
class Log extends Serializable<LogDocument> {
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the "Elastic License
|
||||
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
|
||||
* Public License v 1"; you may not use this file except in compliance with, at
|
||||
* your election, the "Elastic License 2.0", the "GNU Affero General Public
|
||||
* License v3.0 only", or the "Server Side Public License, v 1".
|
||||
*/
|
||||
|
||||
import { LogDocument, log } from '@kbn/apm-synthtrace-client';
|
||||
import moment from 'moment';
|
||||
import { random } from 'lodash';
|
||||
import { Scenario } from '../cli/scenario';
|
||||
import { withClient } from '../lib/utils/with_client';
|
||||
import { parseLogsScenarioOpts } from './helpers/logs_scenario_opts_parser';
|
||||
import { IndexTemplateName } from '../lib/logs/custom_logsdb_index_templates';
|
||||
|
||||
const scenario: Scenario<LogDocument> = async (runOptions) => {
|
||||
const { isLogsDb } = parseLogsScenarioOpts(runOptions.scenarioOpts);
|
||||
|
||||
return {
|
||||
bootstrap: async ({ logsEsClient }) => {
|
||||
if (isLogsDb) await logsEsClient.createIndexTemplate(IndexTemplateName.LogsDb);
|
||||
},
|
||||
teardown: async ({ logsEsClient }) => {
|
||||
if (isLogsDb) await logsEsClient.deleteIndexTemplate(IndexTemplateName.LogsDb);
|
||||
},
|
||||
|
||||
generate: ({ range, clients: { logsEsClient } }) => {
|
||||
const { logger } = runOptions;
|
||||
|
||||
// Normal access logs
|
||||
const normalAccessLogs = range
|
||||
.interval('1m')
|
||||
.rate(50)
|
||||
.generator((timestamp) => {
|
||||
return Array(5)
|
||||
.fill(0)
|
||||
.map(() => {
|
||||
const logsData = constructApacheLogData();
|
||||
|
||||
return log
|
||||
.create({ isLogsDb })
|
||||
.message(
|
||||
`${logsData['client.ip']} - - [${moment(timestamp).format(
|
||||
'DD/MMM/YYYY:HH:mm:ss Z'
|
||||
)}] "${logsData['http.request.method']} ${logsData['url.path']} HTTP/${
|
||||
logsData['http.version']
|
||||
}" ${logsData['http.response.status_code']} ${logsData['http.response.bytes']}`
|
||||
)
|
||||
.dataset('apache.access')
|
||||
.defaults(logsData)
|
||||
.timestamp(timestamp);
|
||||
});
|
||||
});
|
||||
|
||||
// attack simulation logs
|
||||
const attackSimulationLogs = range
|
||||
.interval('1m')
|
||||
.rate(2)
|
||||
.generator((timestamp) => {
|
||||
return Array(2)
|
||||
.fill(0)
|
||||
.map(() => {
|
||||
const logsData = constructApacheLogData();
|
||||
|
||||
return log
|
||||
.create({ isLogsDb })
|
||||
.message(
|
||||
`ATTACK SIMULATION: ${logsData['client.ip']} attempted access to restricted path ${logsData['url.path']}`
|
||||
)
|
||||
.dataset('apache.security')
|
||||
.logLevel('warning')
|
||||
.defaults({
|
||||
...logsData,
|
||||
'event.category': 'network',
|
||||
'event.type': 'access',
|
||||
'event.outcome': 'failure',
|
||||
})
|
||||
.timestamp(timestamp);
|
||||
});
|
||||
});
|
||||
|
||||
return withClient(
|
||||
logsEsClient,
|
||||
logger.perf('generating_apache_logs', () => [normalAccessLogs, attackSimulationLogs])
|
||||
);
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
export default scenario;
|
||||
|
||||
function constructApacheLogData(): LogDocument {
|
||||
const APACHE_LOG_SCENARIOS = [
|
||||
{
|
||||
method: 'GET',
|
||||
path: '/index.html',
|
||||
responseCode: 200,
|
||||
userAgent: 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
|
||||
referrer: 'https://www.google.com',
|
||||
},
|
||||
{
|
||||
method: 'POST',
|
||||
path: '/login',
|
||||
responseCode: 401,
|
||||
userAgent: 'PostmanRuntime/7.29.0',
|
||||
referrer: '-',
|
||||
},
|
||||
{
|
||||
method: 'GET',
|
||||
path: '/admin/dashboard',
|
||||
responseCode: 403,
|
||||
userAgent: 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7)',
|
||||
referrer: 'https://example.com/home',
|
||||
},
|
||||
];
|
||||
|
||||
const HOSTNAMES = ['www.example.com', 'blog.example.com', 'api.example.com'];
|
||||
const CLOUD_REGIONS = ['us-east-1', 'eu-west-2', 'ap-southeast-1'];
|
||||
|
||||
const index = Math.floor(Math.random() * APACHE_LOG_SCENARIOS.length);
|
||||
const { method, path, responseCode, userAgent, referrer } = APACHE_LOG_SCENARIOS[index];
|
||||
|
||||
const clientIp = generateIpAddress();
|
||||
const hostname = HOSTNAMES[Math.floor(Math.random() * HOSTNAMES.length)];
|
||||
const cloudRegion = CLOUD_REGIONS[Math.floor(Math.random() * CLOUD_REGIONS.length)];
|
||||
|
||||
return {
|
||||
'http.request.method': method,
|
||||
'url.path': path,
|
||||
'http.response.status_code': responseCode,
|
||||
hostname,
|
||||
'cloud.region': cloudRegion,
|
||||
'cloud.availability_zone': `${cloudRegion}a`,
|
||||
'client.ip': clientIp,
|
||||
'user_agent.name': userAgent,
|
||||
'http.request.referrer': referrer,
|
||||
};
|
||||
}
|
||||
|
||||
function generateIpAddress() {
|
||||
return `${random(0, 255)}.${random(0, 255)}.${random(0, 255)}.${random(0, 255)}`;
|
||||
}
|
|
@ -16,7 +16,7 @@ import { FunctionCallChatFunction } from '../../service/types';
|
|||
const SELECT_RELEVANT_FIELDS_NAME = 'select_relevant_fields';
|
||||
export const GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE = `You are a helpful assistant for Elastic Observability.
|
||||
Your task is to determine which fields are relevant to the conversation by selecting only the field IDs from the provided list.
|
||||
The list in the user message consists of JSON objects that map a human-readable "field" name to its unique "id".
|
||||
The list in the user message consists of JSON objects that map a human-readable field "name" to its unique "id".
|
||||
You must not output any field names — only the corresponding "id" values. Ensure that your output follows the exact JSON format specified.`;
|
||||
|
||||
export async function getRelevantFieldNames({
|
||||
|
@ -114,10 +114,12 @@ export async function getRelevantFieldNames({
|
|||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
content: `Below is a list of fields. Each entry is a JSON object that contains a "field" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:
|
||||
content: `Below is a list of fields. Each entry is a JSON object that contains a "name" (the field name) and an "id" (the unique identifier). Use only the "id" values from this list when selecting relevant fields:
|
||||
|
||||
${fieldsInChunk
|
||||
.map((field) => JSON.stringify({ field, id: shortIdTable.take(field) }))
|
||||
.map((fieldName) =>
|
||||
JSON.stringify({ name: fieldName, id: shortIdTable.take(fieldName) })
|
||||
)
|
||||
.join('\n')}`,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -5,8 +5,11 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
import { IScopedClusterClient, Logger } from '@kbn/core/server';
|
||||
import { Message } from '../../../common';
|
||||
import { FunctionRegistrationParameters } from '..';
|
||||
import { FunctionVisibility } from '../../../common/functions/types';
|
||||
import { FunctionCallChatFunction, RespondFunctionResources } from '../../service/types';
|
||||
import { getRelevantFieldNames } from './get_relevant_field_names';
|
||||
|
||||
export const GET_DATASET_INFO_FUNCTION_NAME = 'get_dataset_info';
|
||||
|
@ -32,68 +35,102 @@ export function registerGetDatasetInfoFunction({
|
|||
index: {
|
||||
type: 'string',
|
||||
description:
|
||||
'index pattern the user is interested in or empty string to get information about all available indices',
|
||||
'Please specify the index pattern(s) that are relevant to the user\'s query. You are allowed to specify multiple, comma-separated patterns like "index1,index2". If you provide an empty string, it will search across all indicies in all clusters. By default indicies that match the pattern in both local and remote clusters are searched. If you want to limit the search to a specific cluster you can prefix the index pattern with the cluster name. For example, "cluster1:my-index".',
|
||||
},
|
||||
},
|
||||
required: ['index'],
|
||||
} as const,
|
||||
},
|
||||
async ({ arguments: { index }, messages, chat }, signal) => {
|
||||
const coreContext = await resources.context.core;
|
||||
|
||||
const esClient = coreContext.elasticsearch.client;
|
||||
const savedObjectsClient = coreContext.savedObjects.client;
|
||||
|
||||
let indices: string[] = [];
|
||||
|
||||
try {
|
||||
const body = await esClient.asCurrentUser.indices.resolveIndex({
|
||||
name: index === '' ? ['*', '*:*'] : index.split(','),
|
||||
expand_wildcards: 'open',
|
||||
});
|
||||
indices = [
|
||||
...body.indices.map((i) => i.name),
|
||||
...body.data_streams.map((d) => d.name),
|
||||
...body.aliases.map((d) => d.name),
|
||||
];
|
||||
} catch (e) {
|
||||
indices = [];
|
||||
}
|
||||
|
||||
if (index === '') {
|
||||
return {
|
||||
content: {
|
||||
indices,
|
||||
fields: [],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
if (indices.length === 0) {
|
||||
return {
|
||||
content: {
|
||||
indices,
|
||||
fields: [],
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const relevantFieldNames = await getRelevantFieldNames({
|
||||
index,
|
||||
messages,
|
||||
esClient: esClient.asCurrentUser,
|
||||
dataViews: await resources.plugins.dataViews.start(),
|
||||
savedObjectsClient,
|
||||
signal,
|
||||
chat,
|
||||
});
|
||||
return {
|
||||
content: {
|
||||
indices: [index],
|
||||
fields: relevantFieldNames.fields,
|
||||
stats: relevantFieldNames.stats,
|
||||
},
|
||||
};
|
||||
async ({ arguments: { index: indexPattern }, messages, chat }, signal) => {
|
||||
const content = await getDatasetInfo({ resources, indexPattern, signal, messages, chat });
|
||||
return { content };
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
export async function getDatasetInfo({
|
||||
resources,
|
||||
indexPattern,
|
||||
signal,
|
||||
messages,
|
||||
chat,
|
||||
}: {
|
||||
resources: RespondFunctionResources;
|
||||
indexPattern: string;
|
||||
signal: AbortSignal;
|
||||
messages: Message[];
|
||||
chat: FunctionCallChatFunction;
|
||||
}) {
|
||||
const coreContext = await resources.context.core;
|
||||
const esClient = coreContext.elasticsearch.client;
|
||||
const savedObjectsClient = coreContext.savedObjects.client;
|
||||
|
||||
const indices = await getIndicesFromIndexPattern(indexPattern, esClient, resources.logger);
|
||||
if (indices.length === 0 || indexPattern === '') {
|
||||
return { indices, fields: [] };
|
||||
}
|
||||
|
||||
try {
|
||||
const { fields, stats } = await getRelevantFieldNames({
|
||||
index: indices,
|
||||
messages,
|
||||
esClient: esClient.asCurrentUser,
|
||||
dataViews: await resources.plugins.dataViews.start(),
|
||||
savedObjectsClient,
|
||||
signal,
|
||||
chat,
|
||||
});
|
||||
return { indices, fields, stats };
|
||||
} catch (e) {
|
||||
resources.logger.error(`Error getting relevant field names: ${e.message}`);
|
||||
return { indices, fields: [] };
|
||||
}
|
||||
}
|
||||
|
||||
async function getIndicesFromIndexPattern(
|
||||
indexPattern: string,
|
||||
esClient: IScopedClusterClient,
|
||||
logger: Logger
|
||||
) {
|
||||
let name: string[] = [];
|
||||
if (indexPattern === '') {
|
||||
name = ['*', '*:*'];
|
||||
} else {
|
||||
name = indexPattern.split(',').flatMap((pattern) => {
|
||||
// search specific cluster
|
||||
if (pattern.includes(':')) {
|
||||
const [cluster, p] = pattern.split(':');
|
||||
return `${cluster}:*${p}*`;
|
||||
}
|
||||
|
||||
// search across local and remote clusters
|
||||
return [`*${pattern}*`, `*:*${pattern}*`];
|
||||
});
|
||||
}
|
||||
|
||||
try {
|
||||
const body = await esClient.asCurrentUser.indices.resolveIndex({
|
||||
name,
|
||||
expand_wildcards: 'open', // exclude hidden and closed indices
|
||||
});
|
||||
|
||||
// if there is an exact match, only return that
|
||||
const hasExactMatch =
|
||||
body.indices.some((i) => i.name === indexPattern) ||
|
||||
body.aliases.some((i) => i.name === indexPattern);
|
||||
|
||||
if (hasExactMatch) {
|
||||
return [indexPattern];
|
||||
}
|
||||
|
||||
// otherwise return all matching indices, data streams, and aliases
|
||||
return [
|
||||
...body.indices.map((i) => i.name),
|
||||
...body.data_streams.map((d) => d.name),
|
||||
...body.aliases.map((d) => d.name),
|
||||
];
|
||||
} catch (e) {
|
||||
logger.error(`Error resolving index pattern: ${e.message}`);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
*/
|
||||
import { notImplemented } from '@hapi/boom';
|
||||
import { nonEmptyStringRt, toBooleanRt } from '@kbn/io-ts-utils';
|
||||
import { context as otelContext } from '@opentelemetry/api';
|
||||
import * as t from 'io-ts';
|
||||
import { v4 } from 'uuid';
|
||||
import { FunctionDefinition } from '../../../common/functions/types';
|
||||
|
@ -14,6 +15,8 @@ import type { RecalledEntry } from '../../service/knowledge_base_service';
|
|||
import { getSystemMessageFromInstructions } from '../../service/util/get_system_message_from_instructions';
|
||||
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
|
||||
import { assistantScopeType } from '../runtime_types';
|
||||
import { getDatasetInfo } from '../../functions/get_dataset_info';
|
||||
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
|
||||
|
||||
const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions',
|
||||
|
@ -78,6 +81,47 @@ const getFunctionsRoute = createObservabilityAIAssistantServerRoute({
|
|||
},
|
||||
});
|
||||
|
||||
const functionDatasetInfoRoute = createObservabilityAIAssistantServerRoute({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
|
||||
params: t.type({
|
||||
query: t.type({ index: t.string, connectorId: t.string }),
|
||||
}),
|
||||
security: {
|
||||
authz: {
|
||||
requiredPrivileges: ['ai_assistant'],
|
||||
},
|
||||
},
|
||||
handler: async (resources) => {
|
||||
const client = await resources.service.getClient({ request: resources.request });
|
||||
|
||||
const {
|
||||
query: { index, connectorId },
|
||||
} = resources.params;
|
||||
|
||||
const controller = new AbortController();
|
||||
resources.request.events.aborted$.subscribe(() => {
|
||||
controller.abort();
|
||||
});
|
||||
|
||||
const resp = await getDatasetInfo({
|
||||
resources,
|
||||
indexPattern: index,
|
||||
signal: controller.signal,
|
||||
messages: [],
|
||||
chat: (operationName, params) => {
|
||||
return client.chat(operationName, {
|
||||
...params,
|
||||
stream: true,
|
||||
tracer: new LangTracer(otelContext.active()),
|
||||
connectorId,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
return resp;
|
||||
},
|
||||
});
|
||||
|
||||
const functionRecallRoute = createObservabilityAIAssistantServerRoute({
|
||||
endpoint: 'POST /internal/observability_ai_assistant/functions/recall',
|
||||
params: t.type({
|
||||
|
@ -176,4 +220,5 @@ export const functionRoutes = {
|
|||
...getFunctionsRoute,
|
||||
...functionRecallRoute,
|
||||
...functionSummariseRoute,
|
||||
...functionDatasetInfoRoute,
|
||||
};
|
||||
|
|
|
@ -17,9 +17,10 @@ import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
|
|||
import { ApmAlertFields } from '../../../../../../../apm_api_integration/tests/alerts/helpers/alerting_api_helper';
|
||||
import {
|
||||
LlmProxy,
|
||||
RelevantField,
|
||||
createLlmProxy,
|
||||
} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
|
||||
import { getMessageAddedEvents } from './helpers';
|
||||
import { chatComplete, getSystemMessage, systemMessageSorted } from './helpers';
|
||||
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
|
||||
import { APM_ALERTS_INDEX } from '../../../apm/alerts/helpers/alerting_helper';
|
||||
|
||||
|
@ -31,8 +32,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
const alertingApi = getService('alertingApi');
|
||||
const samlAuth = getService('samlAuth');
|
||||
|
||||
describe('function: get_alerts_dataset_info', function () {
|
||||
// Fails on MKI: https://github.com/elastic/kibana/issues/205581
|
||||
describe('get_alerts_dataset_info', function () {
|
||||
this.tags(['failsOnMKI']);
|
||||
let llmProxy: LlmProxy;
|
||||
let connectorId: string;
|
||||
|
@ -40,8 +40,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
let apmSynthtraceEsClient: ApmSynthtraceEsClient;
|
||||
let roleAuthc: RoleCredentials;
|
||||
let createdRuleId: string;
|
||||
let expectedRelevantFieldNames: string[];
|
||||
let primarySystemMessage: string;
|
||||
let getRelevantFields: () => Promise<RelevantField[]>;
|
||||
|
||||
before(async () => {
|
||||
({ apmSynthtraceEsClient } = await createSyntheticApmData(getService));
|
||||
|
@ -58,26 +57,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
when: () => true,
|
||||
});
|
||||
|
||||
void llmProxy.interceptWithFunctionRequest({
|
||||
name: 'select_relevant_fields',
|
||||
// @ts-expect-error
|
||||
when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields',
|
||||
arguments: (requestBody) => {
|
||||
const userMessage = last(requestBody.messages);
|
||||
const topFields = (userMessage?.content as string)
|
||||
.slice(204) // remove the prefix message and only get the JSON
|
||||
.trim()
|
||||
.split('\n')
|
||||
.map((line) => JSON.parse(line))
|
||||
.slice(0, 5);
|
||||
|
||||
expectedRelevantFieldNames = topFields.map(({ field }) => field);
|
||||
|
||||
const fieldIds = topFields.map(({ id }) => id);
|
||||
|
||||
return JSON.stringify({ fieldIds });
|
||||
},
|
||||
});
|
||||
({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice());
|
||||
|
||||
void llmProxy.interceptWithFunctionRequest({
|
||||
name: 'alerts',
|
||||
|
@ -89,44 +69,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
`You have active alerts for the past 10 days. Back to work!`
|
||||
);
|
||||
|
||||
const { status, body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
|
||||
params: {
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
content: USER_MESSAGE,
|
||||
},
|
||||
},
|
||||
],
|
||||
connectorId,
|
||||
persist: false,
|
||||
screenContexts: [],
|
||||
scopes: ['observability' as const],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(status).to.be(200);
|
||||
({ messageAddedEvents } = await chatComplete({
|
||||
userPrompt: USER_MESSAGE,
|
||||
connectorId,
|
||||
observabilityAIAssistantAPIClient,
|
||||
}));
|
||||
|
||||
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
|
||||
messageAddedEvents = getMessageAddedEvents(body);
|
||||
|
||||
const {
|
||||
body: { systemMessage },
|
||||
} = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions',
|
||||
params: {
|
||||
query: {
|
||||
scopes: ['observability'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
primarySystemMessage = systemMessage;
|
||||
});
|
||||
|
||||
after(async () => {
|
||||
|
@ -146,7 +95,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
await samlAuth.invalidateM2mApiKeyWithRoleScope(roleAuthc);
|
||||
});
|
||||
|
||||
describe('LLM requests', () => {
|
||||
describe('POST /internal/observability_ai_assistant/chat/complete', () => {
|
||||
let firstRequestBody: ChatCompletionStreamParams;
|
||||
let secondRequestBody: ChatCompletionStreamParams;
|
||||
let thirdRequestBody: ChatCompletionStreamParams;
|
||||
|
@ -163,6 +112,51 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
expect(llmProxy.interceptedRequests.length).to.be(4);
|
||||
});
|
||||
|
||||
it('emits 7 messageAdded events', () => {
|
||||
expect(messageAddedEvents.length).to.be(7);
|
||||
});
|
||||
|
||||
it('emits messageAdded events in the correct order', async () => {
|
||||
const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => {
|
||||
const { role, name, function_call: functionCall } = message.message;
|
||||
if (functionCall) {
|
||||
return { function_call: functionCall, role };
|
||||
}
|
||||
|
||||
return { name, role };
|
||||
});
|
||||
|
||||
expect(formattedMessageAddedEvents).to.eql([
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: { name: 'context', trigger: 'assistant' },
|
||||
},
|
||||
{ name: 'context', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'get_alerts_dataset_info',
|
||||
arguments: '{"start":"now-10d","end":"now"}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
},
|
||||
{ name: 'get_alerts_dataset_info', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'alerts',
|
||||
arguments: '{"start":"now-10d","end":"now"}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
},
|
||||
{ name: 'alerts', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: { name: '', arguments: '', trigger: 'assistant' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
describe('every request to the LLM', () => {
|
||||
it('contains a system message', () => {
|
||||
const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every(
|
||||
|
@ -228,9 +222,10 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
});
|
||||
|
||||
describe('The system message', () => {
|
||||
it('has the primary system message', () => {
|
||||
expect(sortSystemMessage(firstRequestBody.messages[0].content as string)).to.eql(
|
||||
sortSystemMessage(primarySystemMessage)
|
||||
it('has the primary system message', async () => {
|
||||
const primarySystemMessage = await getSystemMessage(getService);
|
||||
expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql(
|
||||
systemMessageSorted(primarySystemMessage)
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -254,14 +249,11 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
});
|
||||
|
||||
it('contains a system generated user message with a list of field candidates', () => {
|
||||
const hasList = secondRequestBody.messages.some(
|
||||
(message) =>
|
||||
message.role === 'user' &&
|
||||
(message.content as string).includes('Below is a list of fields.') &&
|
||||
(message.content as string).includes('@timestamp')
|
||||
);
|
||||
const lastMessage = last(secondRequestBody.messages);
|
||||
|
||||
expect(hasList).to.be(true);
|
||||
expect(lastMessage?.role).to.be('user');
|
||||
expect(lastMessage?.content).to.contain('Below is a list of fields');
|
||||
expect(lastMessage?.content).to.contain('@timestamp');
|
||||
});
|
||||
|
||||
it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => {
|
||||
|
@ -294,7 +286,7 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
expect(hasFunctionRequest).to.be(true);
|
||||
});
|
||||
|
||||
it('contains the `get_alerts_dataset_info` response', () => {
|
||||
it('contains the `get_alerts_dataset_info` response', async () => {
|
||||
const functionResponse = last(thirdRequestBody.messages);
|
||||
const parsedContent = JSON.parse(functionResponse?.content as string) as {
|
||||
fields: string[];
|
||||
|
@ -303,7 +295,8 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
const fieldNamesWithType = parsedContent.fields;
|
||||
const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]);
|
||||
|
||||
expect(fieldNamesWithoutType).to.eql(expectedRelevantFieldNames);
|
||||
const relevantFields = await getRelevantFields();
|
||||
expect(fieldNamesWithoutType).to.eql(relevantFields.map(({ name }) => name));
|
||||
expect(fieldNamesWithType).to.eql([
|
||||
'@timestamp:date',
|
||||
'_id:_id',
|
||||
|
@ -314,13 +307,13 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
});
|
||||
|
||||
it('emits a messageAdded event with the `get_alerts_dataset_info` function response', async () => {
|
||||
const messageWithDatasetInfo = messageAddedEvents.find(
|
||||
const eventWithDatasetInfo = messageAddedEvents.find(
|
||||
({ message }) =>
|
||||
message.message.role === MessageRole.User &&
|
||||
message.message.name === 'get_alerts_dataset_info'
|
||||
);
|
||||
|
||||
const parsedContent = JSON.parse(messageWithDatasetInfo?.message.message.content!) as {
|
||||
const parsedContent = JSON.parse(eventWithDatasetInfo?.message.message.content!) as {
|
||||
fields: string[];
|
||||
};
|
||||
|
||||
|
@ -361,12 +354,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
});
|
||||
|
||||
it('emits a messageAdded event with the `alert` function response', async () => {
|
||||
const messageWithAlerts = messageAddedEvents.find(
|
||||
const event = messageAddedEvents.find(
|
||||
({ message }) =>
|
||||
message.message.role === MessageRole.User && message.message.name === 'alerts'
|
||||
);
|
||||
|
||||
const parsedContent = JSON.parse(messageWithAlerts?.message.message.content!) as {
|
||||
const parsedContent = JSON.parse(event?.message.message.content!) as {
|
||||
total: number;
|
||||
alerts: any[];
|
||||
};
|
||||
|
@ -375,53 +368,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
|
|||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('messageAdded events', () => {
|
||||
it('emits 7 messageAdded events', () => {
|
||||
expect(messageAddedEvents.length).to.be(7);
|
||||
});
|
||||
|
||||
it('emits messageAdded events in the correct order', async () => {
|
||||
const formattedMessageAddedEvents = messageAddedEvents.map(({ message }) => {
|
||||
const { role, name, function_call: functionCall } = message.message;
|
||||
if (functionCall) {
|
||||
return { function_call: functionCall, role };
|
||||
}
|
||||
|
||||
return { name, role };
|
||||
});
|
||||
|
||||
expect(formattedMessageAddedEvents).to.eql([
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: { name: 'context', trigger: 'assistant' },
|
||||
},
|
||||
{ name: 'context', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'get_alerts_dataset_info',
|
||||
arguments: '{"start":"now-10d","end":"now"}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
},
|
||||
{ name: 'get_alerts_dataset_info', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: {
|
||||
name: 'alerts',
|
||||
arguments: '{"start":"now-10d","end":"now"}',
|
||||
trigger: 'assistant',
|
||||
},
|
||||
},
|
||||
{ name: 'alerts', role: 'user' },
|
||||
{
|
||||
role: 'assistant',
|
||||
function_call: { name: '', arguments: '', trigger: 'assistant' },
|
||||
},
|
||||
]);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -490,11 +436,3 @@ async function createSyntheticApmData(
|
|||
|
||||
return { apmSynthtraceEsClient };
|
||||
}
|
||||
|
||||
// order of instructions can vary, so we sort to compare them
|
||||
function sortSystemMessage(message: string) {
|
||||
return message
|
||||
.split('\n\n')
|
||||
.map((line) => line.trim())
|
||||
.sort();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,341 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { MessageAddEvent, MessageRole } from '@kbn/observability-ai-assistant-plugin/common';
|
||||
import expect from '@kbn/expect';
|
||||
import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace';
|
||||
import { last } from 'lodash';
|
||||
import { GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE } from '@kbn/observability-ai-assistant-plugin/server/functions/get_dataset_info/get_relevant_field_names';
|
||||
import { ChatCompletionStreamParams } from 'openai/lib/ChatCompletionStream';
|
||||
import {
|
||||
LlmProxy,
|
||||
RelevantField,
|
||||
createLlmProxy,
|
||||
} from '../../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
|
||||
import { chatComplete, getSystemMessage, systemMessageSorted } from './helpers';
|
||||
import type { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
|
||||
import { createSimpleSyntheticLogs } from '../../synthtrace_scenarios/simple_logs';
|
||||
|
||||
export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderContext) {
|
||||
const log = getService('log');
|
||||
const observabilityAIAssistantAPIClient = getService('observabilityAIAssistantApi');
|
||||
const synthtrace = getService('synthtrace');
|
||||
|
||||
describe('get_dataset_info', function () {
|
||||
this.tags(['failsOnMKI']);
|
||||
let llmProxy: LlmProxy;
|
||||
let connectorId: string;
|
||||
|
||||
before(async () => {
|
||||
llmProxy = await createLlmProxy(log);
|
||||
connectorId = await observabilityAIAssistantAPIClient.createProxyActionConnector({
|
||||
port: llmProxy.getPort(),
|
||||
});
|
||||
});
|
||||
|
||||
after(async () => {
|
||||
llmProxy.close();
|
||||
await observabilityAIAssistantAPIClient.deleteActionConnector({
|
||||
actionId: connectorId,
|
||||
});
|
||||
});
|
||||
|
||||
// Calling `get_dataset_info` via the chat/complete endpoint
|
||||
describe('POST /internal/observability_ai_assistant/chat/complete', function () {
|
||||
let messageAddedEvents: MessageAddEvent[];
|
||||
let logsSynthtraceEsClient: LogsSynthtraceEsClient;
|
||||
let getRelevantFields: () => Promise<RelevantField[]>;
|
||||
let firstRequestBody: ChatCompletionStreamParams;
|
||||
let secondRequestBody: ChatCompletionStreamParams;
|
||||
let thirdRequestBody: ChatCompletionStreamParams;
|
||||
|
||||
const USER_MESSAGE = 'Do I have any Apache logs?';
|
||||
|
||||
before(async () => {
|
||||
logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient();
|
||||
await createSimpleSyntheticLogs({ logsSynthtraceEsClient });
|
||||
|
||||
void llmProxy.interceptWithFunctionRequest({
|
||||
name: 'get_dataset_info',
|
||||
arguments: () => JSON.stringify({ index: 'logs*' }),
|
||||
when: () => true,
|
||||
});
|
||||
|
||||
({ getRelevantFields } = llmProxy.interceptSelectRelevantFieldsToolChoice());
|
||||
|
||||
void llmProxy.interceptConversation(`Yes, you do have logs. Congratulations! 🎈️🎈️🎈️`);
|
||||
|
||||
({ messageAddedEvents } = await chatComplete({
|
||||
userPrompt: USER_MESSAGE,
|
||||
connectorId,
|
||||
observabilityAIAssistantAPIClient,
|
||||
}));
|
||||
|
||||
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
|
||||
|
||||
firstRequestBody = llmProxy.interceptedRequests[0].requestBody;
|
||||
secondRequestBody = llmProxy.interceptedRequests[1].requestBody;
|
||||
thirdRequestBody = llmProxy.interceptedRequests[2].requestBody;
|
||||
});
|
||||
|
||||
after(async () => {
|
||||
await logsSynthtraceEsClient.clean();
|
||||
});
|
||||
|
||||
it('makes 3 requests to the LLM', () => {
|
||||
expect(llmProxy.interceptedRequests.length).to.be(3);
|
||||
});
|
||||
|
||||
it('emits 5 messageAdded events', () => {
|
||||
expect(messageAddedEvents.length).to.be(5);
|
||||
});
|
||||
|
||||
describe('every request to the LLM', () => {
|
||||
it('contains a system message', () => {
|
||||
const everyRequestHasSystemMessage = llmProxy.interceptedRequests.every(
|
||||
({ requestBody }) => {
|
||||
const firstMessage = requestBody.messages[0];
|
||||
return (
|
||||
firstMessage.role === 'system' &&
|
||||
(firstMessage.content as string).includes('You are a helpful assistant')
|
||||
);
|
||||
}
|
||||
);
|
||||
expect(everyRequestHasSystemMessage).to.be(true);
|
||||
});
|
||||
|
||||
it('contains the original user message', () => {
|
||||
const everyRequestHasUserMessage = llmProxy.interceptedRequests.every(({ requestBody }) =>
|
||||
requestBody.messages.some(
|
||||
(message) => message.role === 'user' && (message.content as string) === USER_MESSAGE
|
||||
)
|
||||
);
|
||||
expect(everyRequestHasUserMessage).to.be(true);
|
||||
});
|
||||
|
||||
it('contains the context function request and context function response', () => {
|
||||
const everyRequestHasContextFunction = llmProxy.interceptedRequests.every(
|
||||
({ requestBody }) => {
|
||||
const hasContextFunctionRequest = requestBody.messages.some(
|
||||
(message) =>
|
||||
message.role === 'assistant' &&
|
||||
message.tool_calls?.[0]?.function?.name === 'context'
|
||||
);
|
||||
|
||||
const hasContextFunctionResponse = requestBody.messages.some(
|
||||
(message) =>
|
||||
message.role === 'tool' &&
|
||||
(message.content as string).includes('screen_description') &&
|
||||
(message.content as string).includes('learnings')
|
||||
);
|
||||
|
||||
return hasContextFunctionRequest && hasContextFunctionResponse;
|
||||
}
|
||||
);
|
||||
|
||||
expect(everyRequestHasContextFunction).to.be(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('The first request', () => {
|
||||
it('contains the correct number of messages', () => {
|
||||
expect(firstRequestBody.messages.length).to.be(4);
|
||||
});
|
||||
|
||||
it('contains the `get_dataset_info` tool', () => {
|
||||
const hasTool = firstRequestBody.tools?.some(
|
||||
(tool) => tool.function.name === 'get_dataset_info'
|
||||
);
|
||||
|
||||
expect(hasTool).to.be(true);
|
||||
});
|
||||
|
||||
it('leaves the function calling decision to the LLM via tool_choice=auto', () => {
|
||||
expect(firstRequestBody.tool_choice).to.be('auto');
|
||||
});
|
||||
|
||||
describe('The system message', () => {
|
||||
it('has the primary system message', async () => {
|
||||
const primarySystemMessage = await getSystemMessage(getService);
|
||||
expect(systemMessageSorted(firstRequestBody.messages[0].content as string)).to.eql(
|
||||
systemMessageSorted(primarySystemMessage)
|
||||
);
|
||||
});
|
||||
|
||||
it('has a different system message from request 2', () => {
|
||||
expect(firstRequestBody.messages[0]).not.to.eql(secondRequestBody.messages[0]);
|
||||
});
|
||||
|
||||
it('has the same system message as request 3', () => {
|
||||
expect(firstRequestBody.messages[0]).to.eql(thirdRequestBody.messages[0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('The second request', () => {
|
||||
it('contains the correct number of messages', () => {
|
||||
expect(secondRequestBody.messages.length).to.be(5);
|
||||
});
|
||||
|
||||
it('contains a system generated user message with a list of field candidates', () => {
|
||||
const lastMessage = last(secondRequestBody.messages);
|
||||
|
||||
expect(lastMessage?.role).to.be('user');
|
||||
expect(lastMessage?.content).to.contain('Below is a list of fields');
|
||||
expect(lastMessage?.content).to.contain('@timestamp');
|
||||
});
|
||||
|
||||
it('instructs the LLM to call the `select_relevant_fields` tool via `tool_choice`', () => {
|
||||
const hasToolChoice =
|
||||
// @ts-expect-error
|
||||
secondRequestBody.tool_choice?.function?.name === 'select_relevant_fields';
|
||||
|
||||
expect(hasToolChoice).to.be(true);
|
||||
});
|
||||
|
||||
it('has a custom, function-specific system message', () => {
|
||||
expect(secondRequestBody.messages[0].content).to.be(
|
||||
GET_RELEVANT_FIELD_NAMES_SYSTEM_MESSAGE
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('The third request', () => {
|
||||
it('contains the correct number of messages', () => {
|
||||
expect(thirdRequestBody.messages.length).to.be(6);
|
||||
});
|
||||
|
||||
it('contains the `get_dataset_info` request', () => {
|
||||
const hasFunctionRequest = thirdRequestBody.messages.some(
|
||||
(message) =>
|
||||
message.role === 'assistant' &&
|
||||
message.tool_calls?.[0]?.function?.name === 'get_dataset_info'
|
||||
);
|
||||
|
||||
expect(hasFunctionRequest).to.be(true);
|
||||
});
|
||||
|
||||
it('contains the `get_dataset_info` response', () => {
|
||||
const functionResponseMessage = last(thirdRequestBody.messages);
|
||||
const parsedContent = JSON.parse(functionResponseMessage?.content as string);
|
||||
expect(Object.keys(parsedContent)).to.eql(['indices', 'fields', 'stats']);
|
||||
expect(parsedContent.indices).to.contain('logs-web.access-default');
|
||||
});
|
||||
|
||||
it('emits a messageAdded event with the `get_dataset_info` function response', async () => {
|
||||
const event = messageAddedEvents.find(
|
||||
({ message }) =>
|
||||
message.message.role === MessageRole.User &&
|
||||
message.message.name === 'get_dataset_info'
|
||||
);
|
||||
|
||||
const parsedContent = JSON.parse(event?.message.message.content!) as {
|
||||
indices: string[];
|
||||
fields: string[];
|
||||
};
|
||||
|
||||
const fieldNamesWithType = parsedContent.fields;
|
||||
const fieldNamesWithoutType = fieldNamesWithType.map((field) => field.split(':')[0]);
|
||||
|
||||
const relevantFields = await getRelevantFields();
|
||||
expect(fieldNamesWithoutType).to.eql(relevantFields.map(({ name }) => name));
|
||||
expect(parsedContent.indices).to.contain('logs-web.access-default');
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Calling `get_dataset_info` directly
|
||||
describe('GET /internal/observability_ai_assistant/functions/get_dataset_info', () => {
|
||||
let logsSynthtraceEsClient: LogsSynthtraceEsClient;
|
||||
|
||||
before(async () => {
|
||||
logsSynthtraceEsClient = synthtrace.createLogsSynthtraceEsClient();
|
||||
await Promise.all([
|
||||
createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'zookeeper.access' }),
|
||||
createSimpleSyntheticLogs({ logsSynthtraceEsClient, dataset: 'apache.access' }),
|
||||
]);
|
||||
});
|
||||
|
||||
after(async () => {
|
||||
await logsSynthtraceEsClient.clean();
|
||||
});
|
||||
|
||||
it('returns Zookeeper logs but not the Apache logs', async () => {
|
||||
llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });
|
||||
|
||||
const { body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
|
||||
params: {
|
||||
query: {
|
||||
index: 'zookeeper',
|
||||
connectorId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(body.indices).to.eql(['logs-zookeeper.access-default']);
|
||||
expect(body.fields.length).to.be.greaterThan(0);
|
||||
});
|
||||
|
||||
it('returns both Zookeeper and Apache logs', async () => {
|
||||
llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });
|
||||
|
||||
const { body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
|
||||
params: {
|
||||
query: {
|
||||
index: 'logs',
|
||||
connectorId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
|
||||
|
||||
expect(body.indices).to.contain('logs-apache.access-default');
|
||||
expect(body.indices).to.contain('logs-zookeeper.access-default');
|
||||
expect(body.fields.length).to.be.greaterThan(0);
|
||||
});
|
||||
|
||||
it('accepts a comma-separated of patterns', async () => {
|
||||
llmProxy.interceptSelectRelevantFieldsToolChoice({ to: 20 });
|
||||
|
||||
const { body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
|
||||
params: {
|
||||
query: {
|
||||
index: 'zookeeper,apache',
|
||||
connectorId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
await llmProxy.waitForAllInterceptorsToHaveBeenCalled();
|
||||
|
||||
expect(body.indices).to.eql([
|
||||
'logs-apache.access-default',
|
||||
'logs-zookeeper.access-default',
|
||||
]);
|
||||
});
|
||||
|
||||
it('handles no matching indices gracefully', async () => {
|
||||
const { body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions/get_dataset_info',
|
||||
params: {
|
||||
query: {
|
||||
index: 'foobarbaz',
|
||||
connectorId,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(body.indices).to.eql([]);
|
||||
expect(body.fields).to.eql([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
|
@ -13,6 +13,7 @@ import {
|
|||
} from '@kbn/observability-ai-assistant-plugin/common';
|
||||
import { Readable } from 'stream';
|
||||
import type { AssistantScope } from '@kbn/ai-assistant-common';
|
||||
import { DeploymentAgnosticFtrProviderContext } from '../../../../../ftr_provider_context';
|
||||
import type { ObservabilityAIAssistantApiClient } from '../../../../../services/observability_ai_assistant_api';
|
||||
|
||||
function decodeEvents(body: Readable | string) {
|
||||
|
@ -73,3 +74,64 @@ export async function invokeChatCompleteWithFunctionRequest({
|
|||
|
||||
return body;
|
||||
}
|
||||
|
||||
export async function chatComplete({
|
||||
userPrompt,
|
||||
connectorId,
|
||||
observabilityAIAssistantAPIClient,
|
||||
}: {
|
||||
userPrompt: string;
|
||||
connectorId: string;
|
||||
observabilityAIAssistantAPIClient: ObservabilityAIAssistantApiClient;
|
||||
}) {
|
||||
const { status, body } = await observabilityAIAssistantAPIClient.editor({
|
||||
endpoint: 'POST /internal/observability_ai_assistant/chat/complete',
|
||||
params: {
|
||||
body: {
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
content: userPrompt,
|
||||
},
|
||||
},
|
||||
],
|
||||
connectorId,
|
||||
persist: false,
|
||||
screenContexts: [],
|
||||
scopes: ['observability' as const],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
expect(status).to.be(200);
|
||||
const messageAddedEvents = getMessageAddedEvents(body);
|
||||
|
||||
return { messageAddedEvents, body, status };
|
||||
}
|
||||
|
||||
// order of instructions can vary, so we sort to compare them
|
||||
export function systemMessageSorted(message: string) {
|
||||
return message
|
||||
.split('\n\n')
|
||||
.map((line) => line.trim())
|
||||
.sort();
|
||||
}
|
||||
|
||||
export async function getSystemMessage(
|
||||
getService: DeploymentAgnosticFtrProviderContext['getService']
|
||||
) {
|
||||
const apiClient = getService('observabilityAIAssistantApi');
|
||||
|
||||
const { body } = await apiClient.editor({
|
||||
endpoint: 'GET /internal/observability_ai_assistant/functions',
|
||||
params: {
|
||||
query: {
|
||||
scopes: ['observability'],
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
return body.systemMessage;
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ export default function aiAssistantApiIntegrationTests({
|
|||
loadTestFile(require.resolve('./complete/complete.spec.ts'));
|
||||
loadTestFile(require.resolve('./complete/functions/alerts.spec.ts'));
|
||||
loadTestFile(require.resolve('./complete/functions/get_alerts_dataset_info.spec.ts'));
|
||||
loadTestFile(require.resolve('./complete/functions/get_dataset_info.spec.ts'));
|
||||
loadTestFile(require.resolve('./complete/functions/elasticsearch.spec.ts'));
|
||||
loadTestFile(require.resolve('./complete/functions/summarize.spec.ts'));
|
||||
loadTestFile(require.resolve('./public_complete/public_complete.spec.ts'));
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { timerange, log } from '@kbn/apm-synthtrace-client';
|
||||
import { LogsSynthtraceEsClient } from '@kbn/apm-synthtrace';
|
||||
|
||||
export async function createSimpleSyntheticLogs({
|
||||
logsSynthtraceEsClient,
|
||||
message,
|
||||
dataset,
|
||||
}: {
|
||||
logsSynthtraceEsClient: LogsSynthtraceEsClient;
|
||||
message?: string;
|
||||
dataset?: string;
|
||||
}) {
|
||||
const range = timerange('now-15m', 'now');
|
||||
|
||||
const simpleLogs = range
|
||||
.interval('1m')
|
||||
.rate(1)
|
||||
.generator((timestamp) =>
|
||||
log
|
||||
.create()
|
||||
.message(message ?? 'simple log message')
|
||||
.dataset(dataset ?? 'web.access')
|
||||
.timestamp(timestamp)
|
||||
);
|
||||
|
||||
await logsSynthtraceEsClient.index([simpleLogs]);
|
||||
}
|
|
@ -13,7 +13,7 @@ interface GetApmSynthtraceEsClientParams {
|
|||
packageVersion: string;
|
||||
}
|
||||
|
||||
export async function getApmSynthtraceEsClient({
|
||||
export function getApmSynthtraceEsClient({
|
||||
client,
|
||||
packageVersion,
|
||||
}: GetApmSynthtraceEsClientParams) {
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import { Client } from '@elastic/elasticsearch';
|
||||
import { InfraSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace';
|
||||
|
||||
export async function getInfraSynthtraceEsClient(client: Client) {
|
||||
export function getInfraSynthtraceEsClient(client: Client) {
|
||||
return new InfraSynthtraceEsClient({
|
||||
client,
|
||||
logger: createLogger(LogLevel.info),
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
import { Client } from '@elastic/elasticsearch';
|
||||
import { LogsSynthtraceEsClient, createLogger, LogLevel } from '@kbn/apm-synthtrace';
|
||||
|
||||
export async function getLogsSynthtraceEsClient(client: Client) {
|
||||
export function getLogsSynthtraceEsClient(client: Client) {
|
||||
return new LogsSynthtraceEsClient({
|
||||
client,
|
||||
logger: createLogger(LogLevel.info),
|
||||
|
|
|
@ -9,7 +9,7 @@ import { ToolingLog } from '@kbn/tooling-log';
|
|||
import getPort from 'get-port';
|
||||
import { v4 as uuidv4 } from 'uuid';
|
||||
import http, { type Server } from 'http';
|
||||
import { isString, once, pull, isFunction } from 'lodash';
|
||||
import { isString, once, pull, isFunction, last } from 'lodash';
|
||||
import { TITLE_CONVERSATION_FUNCTION_NAME } from '@kbn/observability-ai-assistant-plugin/server/service/client/operators/get_generated_title';
|
||||
import pRetry from 'p-retry';
|
||||
import type { ChatCompletionChunkToolCall } from '@kbn/inference-common';
|
||||
|
@ -36,6 +36,12 @@ export interface ToolMessage {
|
|||
content?: string;
|
||||
tool_calls?: ChatCompletionChunkToolCall[];
|
||||
}
|
||||
|
||||
/**
 * A field candidate parsed from the `select_relevant_fields` prompt.
 * The `id` is echoed back to the assistant in the tool-call arguments
 * (`fieldIds`); `name` is the field name — see
 * `interceptSelectRelevantFieldsToolChoice`.
 */
export interface RelevantField {
  id: string;
  name: string;
}
|
||||
|
||||
export interface LlmResponseSimulator {
|
||||
requestBody: ChatCompletionStreamParams;
|
||||
status: (code: number) => void;
|
||||
|
@ -180,6 +186,39 @@ export class LlmProxy {
|
|||
}).completeAfterIntercept();
|
||||
}
|
||||
|
||||
  /**
   * Intercepts LLM requests whose `tool_choice` forces the
   * `select_relevant_fields` tool and responds with a slice of the field
   * candidates taken from the request's own last message.
   *
   * @param from - start index (inclusive) into the candidate lines, default 0.
   * @param to - end index (exclusive) into the candidate lines, default 5.
   * @returns the underlying simulator plus `getRelevantFields`, which resolves
   *   with the fields captured from the intercepted request (it awaits the
   *   simulator first, so it only resolves after the interception happened).
   */
  interceptSelectRelevantFieldsToolChoice({
    from = 0,
    to = 5,
  }: { from?: number; to?: number } = {}) {
    // Captured when the interceptor fires; read later via getRelevantFields().
    let relevantFields: RelevantField[] = [];
    const simulator = this.interceptWithFunctionRequest({
      name: 'select_relevant_fields',
      // @ts-expect-error
      when: (requestBody) => requestBody.tool_choice?.function?.name === 'select_relevant_fields',
      arguments: (requestBody) => {
        // The last message is assumed to contain an intro paragraph followed by
        // one JSON object per line ({ id, name } shaped) — TODO confirm against
        // get_relevant_field_names, which builds that prompt.
        const messageWithFieldIds = last(requestBody.messages);
        relevantFields = (messageWithFieldIds?.content as string)
          .split('\n\n') // drop the intro paragraph…
          .slice(1)
          .join('')
          .trim()
          .split('\n') // …then parse the JSON lines in the [from, to) window
          .slice(from, to)
          .map((line) => JSON.parse(line) as RelevantField);

        // Answer the tool call with the ids of the selected fields.
        return JSON.stringify({ fieldIds: relevantFields.map(({ id }) => id) });
      },
    });

    return {
      simulator,
      getRelevantFields: async () => {
        await simulator;
        return relevantFields;
      },
    };
  }
|
||||
|
||||
interceptTitle(title: string) {
|
||||
return this.interceptWithFunctionRequest({
|
||||
name: TITLE_CONVERSATION_FUNCTION_NAME,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue