[Obs AI Assistant] Add APM instrumentation for chat/complete (#175799)

Adds APM instrumentation for function calls and interactions with the
LLM.

Closes https://github.com/elastic/obs-ai-assistant-team/issues/120
This commit is contained in:
Dario Gieselaar 2024-02-05 10:40:54 +01:00 committed by GitHub
parent ff3c1af8ee
commit 441ed1b662
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 150 additions and 88 deletions

View file

@@ -57,7 +57,7 @@ describe('createChatService', () => {
}
function chat({ signal }: { signal: AbortSignal } = { signal: new AbortController().signal }) {
return service.chat({ signal, messages: [], connectorId: '' });
return service.chat('my_test', { signal, messages: [], connectorId: '' });
}
beforeEach(async () => {

View file

@@ -174,17 +174,20 @@ export async function createChatService({
});
});
},
chat({
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}) {
chat(
name: string,
{
connectorId,
messages,
function: callFunctions = 'auto',
signal,
}: {
connectorId: string;
messages: Message[];
function?: 'none' | 'auto';
signal: AbortSignal;
}
) {
return new Observable<StreamingChatResponseEventWithoutError>((subscriber) => {
const contexts = ['core', 'apm'];
@@ -193,6 +196,7 @@ export async function createChatService({
client('POST /internal/observability_ai_assistant/chat', {
params: {
body: {
name,
messages,
connectorId,
functions:

View file

@@ -50,12 +50,15 @@ export type { PendingMessage };
export interface ObservabilityAIAssistantChatService {
analytics: AnalyticsServiceStart;
chat: (options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}) => Observable<StreamingChatResponseEventWithoutError>;
chat: (
name: string,
options: {
messages: Message[];
connectorId: string;
function?: 'none' | 'auto';
signal: AbortSignal;
}
) => Observable<StreamingChatResponseEventWithoutError>;
complete: (options: {
messages: Message[];
connectorId: string;

View file

@@ -186,17 +186,21 @@ export class KibanaClient {
unregister: () => void;
}> = [];
async function chat({
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}) {
async function chat(
name: string,
{
messages,
functions,
functionCall,
}: {
messages: Message[];
functions: FunctionDefinition[];
functionCall?: string;
}
) {
const params: ObservabilityAIAssistantAPIClientRequestParamsOf<'POST /internal/observability_ai_assistant/chat'>['params']['body'] =
{
name,
messages,
connectorId,
functions: functions.map((fn) => pick(fn, 'name', 'description', 'parameters')),
@@ -235,7 +239,7 @@ export class KibanaClient {
'@timestamp': new Date().toISOString(),
})),
];
return chat({ messages, functions: functionDefinitions });
return chat('chat', { messages, functions: functionDefinitions });
},
complete: async (...args) => {
const messagesArg = args.length === 1 ? args[0] : args[1];
@@ -298,7 +302,7 @@ export class KibanaClient {
};
},
evaluate: async ({ messages, conversationId }, criteria) => {
const message = await chat({
const message = await chat('evaluate', {
messages: [
{
'@timestamp': new Date().toISOString(),

View file

@@ -124,7 +124,7 @@ export function registerEsqlFunction({
];
const source$ = (
await client.chat({
await client.chat('classify_esql', {
connectorId,
messages: withEsqlSystemMessage(
`Use the classify_esql function to classify the user's request
@@ -198,10 +198,12 @@ export function registerEsqlFunction({
const messagesToInclude = mapValues(pick(esqlDocs, keywords), ({ data }) => data);
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat({
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
const esqlResponse$: Observable<ChatCompletionChunkEvent> = await client.chat(
'answer_esql_question',
{
messages: [
...withEsqlSystemMessage(
`Format every ES|QL query as Markdown:
\`\`\`esql
<query>
\`\`\`
@@ -224,33 +226,34 @@ export function registerEsqlFunction({
\`\`\`
`
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.Assistant,
content: '',
function_call: {
name: 'get_esql_info',
arguments: JSON.stringify(args),
trigger: MessageRole.Assistant as const,
},
},
},
},
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
{
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.User,
name: 'get_esql_info',
content: JSON.stringify({
documentation: messagesToInclude,
}),
},
},
},
],
connectorId,
signal,
});
],
connectorId,
signal,
}
);
return esqlResponse$.pipe(
emitWithConcatenatedMessage((msg) => {

View file

@@ -115,7 +115,7 @@ export function registerGetDatasetInfoFunction({
const relevantFields = await Promise.all(
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await client.chat({
await client.chat('get_relevent_dataset_names', {
connectorId,
signal,
messages: [

View file

@@ -248,7 +248,7 @@ async function scoreSuggestions({
const response = await lastValueFrom(
(
await client.chat({
await client.chat('score_suggestions', {
connectorId,
messages: [extendedSystemMessage, newUserMessage],
functions: [scoreFunction],

View file

@@ -21,6 +21,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
params: t.type({
body: t.intersection([
t.type({
name: t.string,
messages: t.array(messageRt),
connectorId: t.string,
functions: t.array(
@@ -46,7 +47,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
}
const {
body: { messages, connectorId, functions, functionCall },
body: { name, messages, connectorId, functions, functionCall },
} = params;
const controller = new AbortController();
@@ -55,7 +56,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
controller.abort();
});
const response$ = await client.chat({
const response$ = await client.chat(name, {
messages,
connectorId,
signal: controller.signal,

View file

@@ -10,6 +10,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import apm from 'elastic-apm-node';
import { decode, encode } from 'gpt-tokenizer';
import { compact, isEmpty, last, merge, noop, omit, pick, take } from 'lodash';
import type OpenAI from 'openai';
@@ -191,17 +192,22 @@ export class ObservabilityAIAssistantClient {
return await next(nextMessages.concat(addedMessage));
} else if (isUserMessage) {
const response$ = (
await this.chat({
messages: nextMessages,
connectorId,
signal,
functions:
numFunctionsCalled >= MAX_FUNCTION_CALLS
? []
: functionClient
.getFunctions()
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
})
await this.chat(
lastMessage.message.name && lastMessage.message.name !== 'recall'
? 'function_response'
: 'user_message',
{
messages: nextMessages,
connectorId,
signal,
functions:
numFunctionsCalled >= MAX_FUNCTION_CALLS
? []
: functionClient
.getFunctions()
.map((fn) => pick(fn.definition, 'name', 'description', 'parameters')),
}
)
).pipe(emitWithConcatenatedMessage(), shareReplay());
response$.subscribe({
@@ -226,6 +232,14 @@ export class ObservabilityAIAssistantClient {
}
if (isAssistantMessageWithFunctionRequest) {
const span = apm.startSpan(
`execute_function ${lastMessage.message.function_call!.name}`
);
span?.addLabels({
ai_assistant_args: JSON.stringify(lastMessage.message.function_call!.arguments ?? {}),
});
const functionResponse =
numFunctionsCalled >= MAX_FUNCTION_CALLS
? {
@@ -247,6 +261,8 @@ export class ObservabilityAIAssistantClient {
return response;
}
span?.setOutcome('success');
const encoded = encode(JSON.stringify(response.content || {}));
if (encoded.length <= MAX_FUNCTION_RESPONSE_TOKEN_COUNT) {
@@ -263,6 +279,7 @@ export class ObservabilityAIAssistantClient {
};
})
.catch((error): FunctionResponse => {
span?.setOutcome('failure');
return {
content: {
message: error.toString(),
@@ -322,8 +339,13 @@ export class ObservabilityAIAssistantClient {
)
);
span?.end();
return await next(nextMessages.concat(messageEvents.map((event) => event.message)));
}
span?.end();
return await next(nextMessages);
}
@@ -401,19 +423,24 @@ export class ObservabilityAIAssistantClient {
).pipe(shareReplay());
};
chat = async ({
messages,
connectorId,
functions,
functionCall,
signal,
}: {
messages: Message[];
connectorId: string;
functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
functionCall?: string;
signal: AbortSignal;
}): Promise<Observable<ChatCompletionChunkEvent>> => {
chat = async (
name: string,
{
messages,
connectorId,
functions,
functionCall,
signal,
}: {
messages: Message[];
connectorId: string;
functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
functionCall?: string;
signal: AbortSignal;
}
): Promise<Observable<ChatCompletionChunkEvent>> => {
const span = apm.startSpan(`chat ${name}`);
const messagesForOpenAI: Array<
Omit<OpenAI.ChatCompletionMessageParam, 'role'> & {
role: MessageRole;
@@ -481,7 +508,24 @@ export class ObservabilityAIAssistantClient {
signal.addEventListener('abort', () => response.destroy());
return streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());
const observable = streamIntoObservable(response).pipe(processOpenAiStream(), shareReplay());
if (span) {
lastValueFrom(observable)
.then(
() => {
span.setOutcome('success');
},
() => {
span.setOutcome('failure');
}
)
.finally(() => {
span.end();
});
}
return observable;
};
find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {
@@ -541,7 +585,7 @@ export class ObservabilityAIAssistantClient {
connectorId: string;
signal: AbortSignal;
}) => {
const response$ = await this.chat({
const response$ = await this.chat('generate_title', {
messages: [
{
'@timestamp': new Date().toISOString(),

View file

@@ -69,6 +69,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
.post(CHAT_API_URL)
.set('kbn-xsrf', 'foo')
.send({
name: 'my_api_call',
messages,
connectorId: 'does not exist',
functions: [],
@@ -96,6 +97,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
.set('kbn-xsrf', 'foo')
.on('error', reject)
.send({
name: 'my_api_call',
messages,
connectorId,
functions: [],
@@ -136,6 +138,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
.post(CHAT_API_URL)
.set('kbn-xsrf', 'foo')
.send({
name: 'my_api_call',
messages,
connectorId,
functions: [],