mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[Obs AI Assistant] Add APM instrumentation for chat/complete (#175799)
Adds APM instrumentation for function calls and interactions with the LLM. Closes https://github.com/elastic/obs-ai-assistant-team/issues/120
This commit is contained in:
parent
ff3c1af8ee
commit
441ed1b662
10 changed files with 150 additions and 88 deletions
|
@ -57,7 +57,7 @@ describe('createChatService', () => {
|
|||
}
|
||||
|
||||
function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) {
|
||||
return service.chat({ signal, messages: [], connectorId: '' });
|
||||
return service.chat('my_test', { signal, messages: [], connectorId: '' });
|
||||
}
|
||||
|
||||
beforeEach(async () => {
|
||||
|
|
|
@ -174,17 +174,20 @@ export async function createChatService({
|
|||
});
|
||||
});
|
||||
},
|
||||
chat({
|
||||
connectorId,
|
||||
messages,
|
||||
function: callFunctions = 'auto',
|
||||
signal,
|
||||
}: {
|
||||
connectorId: string;
|
||||
messages: Message[];
|
||||
function?: 'none' | 'auto';
|
||||
signal: AbortSignal;
|
||||
}) {
|
||||
chat(
|
||||
name: string,
|
||||
{
|
||||
connectorId,
|
||||
messages,
|
||||
function: callFunctions = 'auto',
|
||||
signal,
|
||||
}: {
|
||||
connectorId: string;
|
||||
messages: Message[];
|
||||
function?: 'none' | 'auto';
|
||||
signal: AbortSignal;
|
||||
}
|
||||
) {
|
||||
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
|
||||
const contexts = ['core', 'apm'];
|
||||
|
||||
|
@ -193,6 +196,7 @@ export async function createChatService({
|
|||
client('POST /internal/observability_ai_assistant/chat', {
|
||||
params: {
|
||||
body: {
|
||||
name,
|
||||
messages,
|
||||
connectorId,
|
||||
functions:
|
||||
|
|
|
@ -50,12 +50,15 @@ export type { PendingMessage };
|
|||
|
||||
export interface ObservabilityAIAssistantChatService {
|
||||
analytics: AnalyticsServiceStart;
|
||||
chat: (options: {
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
function?: 'none' | 'auto';
|
||||
signal: AbortSignal;
|
||||
}) => Observable<StreamingChatResponseEventWithoutError>;
|
||||
chat: (
|
||||
name: string,
|
||||
options: {
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
function?: 'none' | 'auto';
|
||||
signal: AbortSignal;
|
||||
}
|
||||
) => Observable<StreamingChatResponseEventWithoutError>;
|
||||
complete: (options: {
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
|
|
|
@ -186,17 +186,21 @@ export class KibanaClient {
|
|||
unregister: () => void;
|
||||
}> = [];
|
||||
|
||||
async function chat({
|
||||
messages,
|
||||
functions,
|
||||
functionCall,
|
||||
}: {
|
||||
messages: Message[];
|
||||
functions: FunctionDefinition[];
|
||||
functionCall?: string;
|
||||
}) {
|
||||
async function chat(
|
||||
name: string,
|
||||
{
|
||||
messages,
|
||||
functions,
|
||||
functionCall,
|
||||
}: {
|
||||
messages: Message[];
|
||||
functions: FunctionDefinition[];
|
||||
functionCall?: string;
|
||||
}
|
||||
) {
|
||||
const params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat'>['params']['body'] =
|
||||
{
|
||||
name,
|
||||
messages,
|
||||
connectorId,
|
||||
functions: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')),
|
||||
|
@ -235,7 +239,7 @@ export class KibanaClient {
|
|||
'@timestamp': new Date().toISOString(),
|
||||
})),
|
||||
];
|
||||
return chat({ messages, functions: functionDefinitions });
|
||||
return chat('chat', { messages, functions: functionDefinitions });
|
||||
},
|
||||
complete: async (...args) => {
|
||||
const messagesArg = args.length === 1 ? args[0] : args[1];
|
||||
|
@ -298,7 +302,7 @@ export class KibanaClient {
|
|||
};
|
||||
},
|
||||
evaluate: async ({ messages, conversationId }, criteria) => {
|
||||
const message = await chat({
|
||||
const message = await chat('evaluate', {
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
|
|
|
@ -124,7 +124,7 @@ export function registerEsqlFunction({
|
|||
];
|
||||
|
||||
const source$ = (
|
||||
await client.chat({
|
||||
await client.chat('classify_esql', {
|
||||
connectorId,
|
||||
messages: withEsqlSystemMessage(
|
||||
`Use the classify_esql function to classify the user's request
|
||||
|
@ -198,10 +198,12 @@ export function registerEsqlFunction({
|
|||
|
||||
const messagesToInclude = mapValues(pick(esqlDocs, keywords), ({ data }) => data);
|
||||
|
||||
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat({
|
||||
messages: [
|
||||
...withEsqlSystemMessage(
|
||||
`Format every ES|QL query as Markdown:
|
||||
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat(
|
||||
'answer_esql_question',
|
||||
{
|
||||
messages: [
|
||||
...withEsqlSystemMessage(
|
||||
`Format every ES|QL query as Markdown:
|
||||
\`\`\`esql
|
||||
<query>
|
||||
\`\`\`
|
||||
|
@ -224,33 +226,34 @@ export function registerEsqlFunction({
|
|||
\`\`\`
|
||||
|
||||
`
|
||||
),
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.Assistant,
|
||||
content: '',
|
||||
function_call: {
|
||||
name: 'get_esql_info',
|
||||
arguments: JSON.stringify(args),
|
||||
trigger: MessageRole.Assistant as const,
|
||||
),
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.Assistant,
|
||||
content: '',
|
||||
function_call: {
|
||||
name: 'get_esql_info',
|
||||
arguments: JSON.stringify(args),
|
||||
trigger: MessageRole.Assistant as const,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
name: 'get_esql_info',
|
||||
content: JSON.stringify({
|
||||
documentation: messagesToInclude,
|
||||
}),
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
message: {
|
||||
role: MessageRole.User,
|
||||
name: 'get_esql_info',
|
||||
content: JSON.stringify({
|
||||
documentation: messagesToInclude,
|
||||
}),
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
connectorId,
|
||||
signal,
|
||||
});
|
||||
],
|
||||
connectorId,
|
||||
signal,
|
||||
}
|
||||
);
|
||||
|
||||
return esqlResponse$.pipe(
|
||||
emitWithConcatenatedMessage((msg) => {
|
||||
|
|
|
@ -115,7 +115,7 @@ export function registerGetDatasetInfoFunction({
|
|||
const relevantFields = await Promise.all(
|
||||
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
|
||||
const chunkResponse$ = (
|
||||
await client.chat({
|
||||
await client.chat('get_relevent_dataset_names', {
|
||||
connectorId,
|
||||
signal,
|
||||
messages: [
|
||||
|
|
|
@ -248,7 +248,7 @@ async function scoreSuggestions({
|
|||
|
||||
const response = await lastValueFrom(
|
||||
(
|
||||
await client.chat({
|
||||
await client.chat('score_suggestions', {
|
||||
connectorId,
|
||||
messages: [extendedSystemMessage, newUserMessage],
|
||||
functions: [scoreFunction],
|
||||
|
|
|
@ -21,6 +21,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
|
|||
params: t.type({
|
||||
body: t.intersection([
|
||||
t.type({
|
||||
name: t.string,
|
||||
messages: t.array(messageRt),
|
||||
connectorId: t.string,
|
||||
functions: t.array(
|
||||
|
@ -46,7 +47,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
|
|||
}
|
||||
|
||||
const {
|
||||
body: { messages, connectorId, functions, functionCall },
|
||||
body: { name, messages, connectorId, functions, functionCall },
|
||||
} = params;
|
||||
|
||||
const controller = new AbortController();
|
||||
|
@ -55,7 +56,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
|
|||
controller.abort();
|
||||
});
|
||||
|
||||
const response$ = await client.chat({
|
||||
const response$ = await client.chat(name, {
|
||||
messages,
|
||||
connectorId,
|
||||
signal: controller.signal,
|
||||
|
|
|
@ -10,6 +10,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
|
|||
import type { ElasticsearchClient } from '@kbn/core/server';
|
||||
import type { Logger } from '@kbn/logging';
|
||||
import type { PublicMethodsOf } from '@kbn/utility-types';
|
||||
import apm from 'elastic-apm-node';
|
||||
import { decode, encode } from 'gpt-tokenizer';
|
||||
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
|
||||
import type OpenAI from 'openai';
|
||||
|
@ -191,17 +192,22 @@ export class ObservabilityAIAssistantClient {
|
|||
return await next(nextMessages.concat(addedMessage));
|
||||
} else if (isUserMessage) {
|
||||
const response$ = (
|
||||
await this.chat({
|
||||
messages: nextMessages,
|
||||
connectorId,
|
||||
signal,
|
||||
functions:
|
||||
numFunctionsCalled >= MAX_FUNCTION_CALLS
|
||||
? []
|
||||
: functionClient
|
||||
.getFunctions()
|
||||
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
|
||||
})
|
||||
await this.chat(
|
||||
lastMessage.message.name && lastMessage.message.name !== 'recall'
|
||||
? 'function_response'
|
||||
: 'user_message',
|
||||
{
|
||||
messages: nextMessages,
|
||||
connectorId,
|
||||
signal,
|
||||
functions:
|
||||
numFunctionsCalled >= MAX_FUNCTION_CALLS
|
||||
? []
|
||||
: functionClient
|
||||
.getFunctions()
|
||||
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
|
||||
}
|
||||
)
|
||||
).pipe(emitWithConcatenatedMessage(), shareReplay());
|
||||
|
||||
response$.subscribe({
|
||||
|
@ -226,6 +232,14 @@ export class ObservabilityAIAssistantClient {
|
|||
}
|
||||
|
||||
if (isAssistantMessageWithFunctionRequest) {
|
||||
const span = apm.startSpan(
|
||||
`execute_function ${lastMessage.message.function_call!.name}`
|
||||
);
|
||||
|
||||
span?.addLabels({
|
||||
ai_assistant_args: JSON.stringify(lastMessage.message.function_call!.arguments ?? {}),
|
||||
});
|
||||
|
||||
const functionResponse =
|
||||
numFunctionsCalled >= MAX_FUNCTION_CALLS
|
||||
? {
|
||||
|
@ -247,6 +261,8 @@ export class ObservabilityAIAssistantClient {
|
|||
return response;
|
||||
}
|
||||
|
||||
span?.setOutcome('success');
|
||||
|
||||
const encoded = encode(JSON.stringify(response.content || {}));
|
||||
|
||||
if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
|
||||
|
@ -263,6 +279,7 @@ export class ObservabilityAIAssistantClient {
|
|||
};
|
||||
})
|
||||
.catch((error): FunctionResponse => {
|
||||
span?.setOutcome('failure');
|
||||
return {
|
||||
content: {
|
||||
message: error.toString(),
|
||||
|
@ -322,8 +339,13 @@ export class ObservabilityAIAssistantClient {
|
|||
)
|
||||
);
|
||||
|
||||
span?.end();
|
||||
|
||||
return await next(nextMessages.concat(messageEvents.map((event) => event.message)));
|
||||
}
|
||||
|
||||
span?.end();
|
||||
|
||||
return await next(nextMessages);
|
||||
}
|
||||
|
||||
|
@ -401,19 +423,24 @@ export class ObservabilityAIAssistantClient {
|
|||
).pipe(shareReplay());
|
||||
};
|
||||
|
||||
chat = async ({
|
||||
messages,
|
||||
connectorId,
|
||||
functions,
|
||||
functionCall,
|
||||
signal,
|
||||
}: {
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
|
||||
functionCall?: string;
|
||||
signal: AbortSignal;
|
||||
}): Promise<Observable<ChatCompletionChunkEvent>> => {
|
||||
chat = async (
|
||||
name: string,
|
||||
{
|
||||
messages,
|
||||
connectorId,
|
||||
functions,
|
||||
functionCall,
|
||||
signal,
|
||||
}: {
|
||||
messages: Message[];
|
||||
connectorId: string;
|
||||
functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
|
||||
functionCall?: string;
|
||||
signal: AbortSignal;
|
||||
}
|
||||
): Promise<Observable<ChatCompletionChunkEvent>> => {
|
||||
const span = apm.startSpan(`chat ${name}`);
|
||||
|
||||
const messagesForOpenAI: Array<
|
||||
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
|
||||
role: MessageRole;
|
||||
|
@ -481,7 +508,24 @@ export class ObservabilityAIAssistantClient {
|
|||
|
||||
signal.addEventListener('abort', () => response.destroy());
|
||||
|
||||
return streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());
|
||||
const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());
|
||||
|
||||
if (span) {
|
||||
lastValueFrom(observable)
|
||||
.then(
|
||||
() => {
|
||||
span.setOutcome('success');
|
||||
},
|
||||
() => {
|
||||
span.setOutcome('failure');
|
||||
}
|
||||
)
|
||||
.finally(() => {
|
||||
span.end();
|
||||
});
|
||||
}
|
||||
|
||||
return observable;
|
||||
};
|
||||
|
||||
find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
|
||||
|
@ -541,7 +585,7 @@ export class ObservabilityAIAssistantClient {
|
|||
connectorId: string;
|
||||
signal: AbortSignal;
|
||||
}) => {
|
||||
const response$ = await this.chat({
|
||||
const response$ = await this.chat('generate_title', {
|
||||
messages: [
|
||||
{
|
||||
'@timestamp': new Date().toISOString(),
|
||||
|
|
|
@ -69,6 +69,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
.post(CHAT_API_URL)
|
||||
.set('kbn-xsrf', 'foo')
|
||||
.send({
|
||||
name: 'my_api_call',
|
||||
messages,
|
||||
connectorId: 'does not exist',
|
||||
functions: [],
|
||||
|
@ -96,6 +97,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
.set('kbn-xsrf', 'foo')
|
||||
.on('error', reject)
|
||||
.send({
|
||||
name: 'my_api_call',
|
||||
messages,
|
||||
connectorId,
|
||||
functions: [],
|
||||
|
@ -136,6 +138,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
|
|||
.post(CHAT_API_URL)
|
||||
.set('kbn-xsrf', 'foo')
|
||||
.send({
|
||||
name: 'my_api_call',
|
||||
messages,
|
||||
connectorId,
|
||||
functions: [],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue