[Obs AI Assistant] Refactor ObservabilityAIAssistantClient (#181255)

Refactors the Observability AI Assistant server-side client. Instead of
using a mix of promises and Observables, we now use Observables where
possible. This leads to more readable code, and makes things like error
handling and logging easier.

This refactor purposely leaves the existing tests in place as much as
possible. The functionality has however been broken into separate
functions so we should be able to break up the existing tests into
smaller pieces.

---------

Co-authored-by: Kibana Machine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Dario Gieselaar 2024-04-30 14:14:52 +02:00 committed by GitHub
parent 8a1d2950fa
commit 6eba59575e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 1274 additions and 883 deletions

View file

@ -103,13 +103,17 @@ export type StreamingChatResponseEvent =
| ConversationCreateEvent
| ConversationUpdateEvent
| MessageAddEvent
| ChatCompletionErrorEvent;
| ChatCompletionErrorEvent
| TokenCountEvent;
export type StreamingChatResponseEventWithoutError = Exclude<
StreamingChatResponseEvent,
ChatCompletionErrorEvent
>;
export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent;
export type MessageOrChatEvent = ChatEvent | MessageAddEvent;
export enum ChatCompletionErrorCode {
InternalError = 'internalError',
NotFoundError = 'notFoundError',

View file

@ -14,7 +14,7 @@ export function createFunctionRequestMessage({
args,
}: {
name: string;
args: unknown;
args?: Record<string, any>;
}): MessageAddEvent {
return {
id: v4(),
@ -28,6 +28,7 @@ export function createFunctionRequestMessage({
trigger: MessageRole.Assistant as const,
},
role: MessageRole.Assistant,
content: '',
},
},
};

View file

@ -24,9 +24,11 @@ export function createFunctionResponseError({
name: error.name,
message: error.message,
cause: error.cause,
stack: error.stack,
},
message: message || error.message,
},
data: {
stack: error.stack,
},
});
}

View file

@ -5,9 +5,20 @@
* 2.0.
*/
import { concat, from, last, mergeMap, Observable, shareReplay, withLatestFrom } from 'rxjs';
import {
concat,
from,
last,
mergeMap,
Observable,
OperatorFunction,
shareReplay,
withLatestFrom,
} from 'rxjs';
import { withoutTokenCountEvents } from './without_token_count_events';
import {
ChatCompletionChunkEvent,
ChatEvent,
MessageAddEvent,
StreamingChatResponseEventType,
} from '../conversation_complete';
@ -40,20 +51,21 @@ function mergeWithEditedMessage(
);
}
export function emitWithConcatenatedMessage(
export function emitWithConcatenatedMessage<T extends ChatEvent>(
callback?: ConcatenateMessageCallback
): (
source$: Observable<ChatCompletionChunkEvent>
) => Observable<ChatCompletionChunkEvent | MessageAddEvent> {
return (source$: Observable<ChatCompletionChunkEvent>) => {
): OperatorFunction<T, T | MessageAddEvent> {
return (source$) => {
const shared = source$.pipe(shareReplay());
const withoutTokenCount$ = shared.pipe(withoutTokenCountEvents());
const response$ = concat(
shared,
shared.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
withLatestFrom(source$),
withLatestFrom(withoutTokenCount$),
mergeMap(([message, chunkEvent]) => {
return mergeWithEditedMessage(message, chunkEvent, callback);
})

View file

@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../conversation_complete';
/**
 * Creates an RxJS operator that removes token count events from a stream.
 *
 * Every event whose `type` is not `StreamingChatResponseEventType.TokenCount`
 * is passed through unchanged, and the output type is narrowed so that
 * `TokenCountEvent` is no longer part of it.
 *
 * @returns An operator mapping `Observable<T>` to
 * `Observable<Exclude<T, TokenCountEvent>>`.
 */
export function withoutTokenCountEvents<T extends StreamingChatResponseEvent>(): OperatorFunction<
  T,
  Exclude<T, TokenCountEvent>
> {
  // Type guard: true for everything except token count events.
  const isNotTokenCountEvent = (event: T): event is Exclude<T, TokenCountEvent> => {
    return event.type !== StreamingChatResponseEventType.TokenCount;
  };

  return filter(isNotTokenCountEvent);
}

View file

@ -38,7 +38,7 @@ export const mockChatService: ObservabilityAIAssistantChatService = {
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: '',
content: 'System',
},
}),
};

View file

@ -284,6 +284,7 @@ describe('complete', () => {
'@timestamp': expect.any(String),
message: {
content: expect.any(String),
data: expect.any(String),
name: 'my_action',
role: MessageRole.User,
},

View file

@ -28,7 +28,6 @@ import {
StreamingChatResponseEventType,
type StreamingChatResponseEventWithoutError,
type StreamingChatResponseEvent,
TokenCountEvent,
} from '../../common/conversation_complete';
import {
FunctionRegistry,
@ -163,13 +162,7 @@ export async function createChatService({
const subscription = toObservable(response)
.pipe(
map(
(line) =>
JSON.parse(line) as
| StreamingChatResponseEvent
| BufferFlushEvent
| TokenCountEvent
),
map((line) => JSON.parse(line) as StreamingChatResponseEvent | BufferFlushEvent),
filter(
(line): line is StreamingChatResponseEvent =>
line.type !== StreamingChatResponseEventType.BufferFlush &&

View file

@ -33,7 +33,7 @@ export const createStorybookChatService = (): ObservabilityAIAssistantChatServic
'@timestamp': new Date().toISOString(),
message: {
role: MessageRole.System,
content: '',
content: 'System',
},
}),
});

View file

@ -21,7 +21,7 @@ import { concatenateChatCompletionChunks } from '../../common/utils/concatenate_
import { createFunctionResponseMessage } from '../../common/utils/create_function_response_message';
import { RecallRanking, RecallRankingEventType } from '../analytics/recall_ranking';
import type { ObservabilityAIAssistantClient } from '../service/client';
import { ChatFn } from '../service/types';
import { FunctionCallChatFunction } from '../service/types';
import { parseSuggestionScores } from './parse_suggestion_scores';
const MAX_TOKEN_COUNT_FOR_DATA_ON_SCREEN = 1000;
@ -61,7 +61,7 @@ export function registerContextFunction({
required: ['queries', 'categories'],
} as const,
},
async ({ arguments: args, messages, connectorId, screenContexts, chat }, signal) => {
async ({ arguments: args, messages, screenContexts, chat }, signal) => {
const { analytics } = (await resources.context.core).coreStart;
const { queries, categories } = args;
@ -118,7 +118,6 @@ export function registerContextFunction({
queries: queriesOrUserPrompt,
messages,
chat,
connectorId,
signal,
logger: resources.logger,
});
@ -209,15 +208,13 @@ async function scoreSuggestions({
messages,
queries,
chat,
connectorId,
signal,
logger,
}: {
suggestions: Awaited<ReturnType<typeof retrieveSuggestions>>;
messages: Message[];
queries: string[];
chat: ChatFn;
connectorId: string;
chat: FunctionCallChatFunction;
signal: AbortSignal;
logger: Logger;
}) {
@ -274,15 +271,12 @@ async function scoreSuggestions({
};
const response = await lastValueFrom(
(
await chat('score_suggestions', {
connectorId,
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
})
).pipe(concatenateChatCompletionChunks())
chat('score_suggestions', {
messages: [...messages.slice(0, -2), newUserMessage],
functions: [scoreFunction],
functionCall: 'score',
signal,
}).pipe(concatenateChatCompletionChunks())
);
const scoreFunctionRequest = decodeOrThrow(scoreFunctionRequestRt)(response);

View file

@ -5,13 +5,13 @@
* 2.0.
*/
import datemath from '@elastic/datemath';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import type { ElasticsearchClient, SavedObjectsClientContract } from '@kbn/core/server';
import type { DataViewsServerPluginStart } from '@kbn/data-views-plugin/server';
import { castArray, chunk, groupBy, uniq } from 'lodash';
import { lastValueFrom, Observable } from 'rxjs';
import type { ObservabilityAIAssistantClient } from '../../service/client';
import { type ChatCompletionChunkEvent, type Message, MessageRole } from '../../../common';
import { lastValueFrom } from 'rxjs';
import { MessageRole, type Message } from '../../../common';
import { concatenateChatCompletionChunks } from '../../../common/utils/concatenate_chat_completion_chunks';
import { FunctionCallChatFunction } from '../../service/types';
export async function getRelevantFieldNames({
index,
@ -22,6 +22,7 @@ export async function getRelevantFieldNames({
savedObjectsClient,
chat,
messages,
signal,
}: {
index: string | string[];
start?: string;
@ -30,13 +31,8 @@ export async function getRelevantFieldNames({
esClient: ElasticsearchClient;
savedObjectsClient: SavedObjectsClientContract;
messages: Message[];
chat: (
name: string,
{}: Pick<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'functionCall' | 'functions' | 'messages'
>
) => Promise<Observable<ChatCompletionChunkEvent>>;
chat: FunctionCallChatFunction;
signal: AbortSignal;
}): Promise<{ fields: string[] }> {
const dataViewsService = await dataViews.dataViewsServiceFactory(savedObjectsClient, esClient);
@ -79,6 +75,7 @@ export async function getRelevantFieldNames({
chunk(fieldNames, 500).map(async (fieldsInChunk) => {
const chunkResponse$ = (
await chat('get_relevent_dataset_names', {
signal,
messages: [
{
'@timestamp': new Date().toISOString(),

View file

@ -37,7 +37,7 @@ export function registerGetDatasetInfoFunction({
required: ['index'],
} as const,
},
async ({ arguments: { index }, messages, connectorId, chat }, signal) => {
async ({ arguments: { index }, messages, chat }, signal) => {
const coreContext = await resources.context.core;
const esClient = coreContext.elasticsearch.client.asCurrentUser;
@ -83,18 +83,8 @@ export function registerGetDatasetInfoFunction({
esClient,
dataViews: await resources.plugins.dataViews.start(),
savedObjectsClient,
chat: (
operationName,
{ messages: nextMessages, functionCall, functions: nextFunctions }
) => {
return chat(operationName, {
messages: nextMessages,
functionCall,
functions: nextFunctions,
connectorId,
signal,
});
},
signal,
chat,
});
return {

View file

@ -8,12 +8,15 @@ import { notImplemented } from '@hapi/boom';
import { toBooleanRt } from '@kbn/io-ts-utils';
import * as t from 'io-ts';
import { Readable } from 'stream';
import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { KibanaRequest } from '@kbn/core/server';
import { aiAssistantSimulatedFunctionCalling } from '../..';
import { flushBuffer } from '../../service/util/flush_buffer';
import { observableIntoStream } from '../../service/util/observable_into_stream';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { screenContextRt, messageRt, functionRt } from '../runtime_types';
import { ObservabilityAIAssistantRouteHandlerResources } from '../types';
import { withAssistantSpan } from '../../service/util/with_assistant_span';
const chatCompleteBaseRt = t.type({
body: t.intersection([
@ -57,6 +60,27 @@ const chatCompletePublicRt = t.intersection([
}),
]);
async function guardAgainstInvalidConnector({
actions,
request,
connectorId,
}: {
actions: ActionsPluginStart;
request: KibanaRequest;
connectorId: string;
}) {
return withAssistantSpan('guard_against_invalid_connector', async () => {
const actionsClient = await actions.getActionsClientWithRequest(request);
const connector = await actionsClient.get({
id: connectorId,
throwIfSystemAction: true,
});
return connector;
});
}
const chatRoute = createObservabilityAIAssistantServerRoute({
endpoint: 'POST /internal/observability_ai_assistant/chat',
options: {
@ -76,7 +100,17 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
]),
}),
handler: async (resources): Promise<Readable> => {
const { request, params, service, context } = resources;
const { request, params, service, context, plugins } = resources;
const {
body: { name, messages, connectorId, functions, functionCall },
} = params;
await guardAgainstInvalidConnector({
actions: await plugins.actions.start(),
request,
connectorId,
});
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
@ -88,17 +122,13 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
throw notImplemented();
}
const {
body: { name, messages, connectorId, functions, functionCall },
} = params;
const controller = new AbortController();
request.events.aborted$.subscribe(() => {
controller.abort();
});
const response$ = await client.chat(name, {
const response$ = client.chat(name, {
messages,
connectorId,
signal: controller.signal,
@ -120,19 +150,7 @@ async function chatComplete(
params: t.TypeOf<typeof chatCompleteInternalRt>;
}
) {
const { request, params, service } = resources;
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
resources.plugins.cloud?.start() || Promise.resolve(undefined),
(
await resources.context.core
).uiSettings.client.get<boolean>(aiAssistantSimulatedFunctionCalling),
]);
if (!client) {
throw notImplemented();
}
const { request, params, service, plugins } = resources;
const {
body: {
@ -147,6 +165,24 @@ async function chatComplete(
},
} = params;
await guardAgainstInvalidConnector({
actions: await plugins.actions.start(),
request,
connectorId,
});
const [client, cloudStart, simulateFunctionCalling] = await Promise.all([
service.getClient({ request }),
resources.plugins.cloud?.start() || Promise.resolve(undefined),
(
await resources.context.core
).uiSettings.client.get<boolean>(aiAssistantSimulatedFunctionCalling),
]);
if (!client) {
throw notImplemented();
}
const controller = new AbortController();
request.events.aborted$.subscribe(() => {

View file

@ -48,7 +48,6 @@ describe('chatFunctionClient', () => {
}),
messages: [],
signal: new AbortController().signal,
connectorId: '',
});
}).rejects.toThrowError(`Function arguments are invalid`);
@ -107,7 +106,6 @@ describe('chatFunctionClient', () => {
name: 'get_data_on_screen',
args: JSON.stringify({ data: ['my_dummy_data'] }),
messages: [],
connectorId: '',
signal: new AbortController().signal,
});

View file

@ -13,7 +13,7 @@ import { FunctionVisibility, type FunctionResponse } from '../../../common/funct
import type { Message, ObservabilityAIAssistantScreenContextRequest } from '../../../common/types';
import { filterFunctionDefinitions } from '../../../common/utils/filter_function_definitions';
import type {
ChatFn,
FunctionCallChatFunction,
FunctionHandler,
FunctionHandlerRegistry,
RegisteredInstruction,
@ -144,14 +144,12 @@ export class ChatFunctionClient {
args,
messages,
signal,
connectorId,
}: {
chat: ChatFn;
chat: FunctionCallChatFunction;
name: string;
args: string | undefined;
messages: Message[];
signal: AbortSignal;
connectorId: string;
}): Promise<FunctionResponse> {
const fn = this.functionRegistry.get(name);
@ -167,7 +165,6 @@ export class ChatFunctionClient {
{
arguments: parsedArguments,
messages,
connectorId,
screenContexts: this.screenContexts,
chat,
},

View file

@ -11,9 +11,9 @@ import { Logger } from '@kbn/logging';
import { concatenateChatCompletionChunks } from '../../../../../common/utils/concatenate_chat_completion_chunks';
import { processBedrockStream } from './process_bedrock_stream';
import { MessageRole } from '../../../../../common';
import { rejectTokenCountEvents } from '../../../util/reject_token_count_events';
import { TOOL_USE_END, TOOL_USE_START } from '../simulate_function_calling/constants';
import { parseInlineFunctionCalls } from '../simulate_function_calling/parse_inline_function_calls';
import { withoutTokenCountEvents } from '../../../../../common/utils/without_token_count_events';
describe('processBedrockStream', () => {
const encodeChunk = (body: unknown) => {
@ -69,7 +69,7 @@ describe('processBedrockStream', () => {
parseInlineFunctionCalls({
logger: getLoggerMock(),
}),
rejectTokenCountEvents(),
withoutTokenCountEvents(),
concatenateChatCompletionChunks()
)
)
@ -101,7 +101,7 @@ describe('processBedrockStream', () => {
parseInlineFunctionCalls({
logger: getLoggerMock(),
}),
rejectTokenCountEvents(),
withoutTokenCountEvents(),
concatenateChatCompletionChunks()
)
)
@ -135,7 +135,7 @@ describe('processBedrockStream', () => {
parseInlineFunctionCalls({
logger: getLoggerMock(),
}),
rejectTokenCountEvents(),
withoutTokenCountEvents(),
concatenateChatCompletionChunks()
)
)
@ -167,7 +167,7 @@ describe('processBedrockStream', () => {
parseInlineFunctionCalls({
logger: getLoggerMock(),
}),
rejectTokenCountEvents(),
withoutTokenCountEvents(),
concatenateChatCompletionChunks()
)
);
@ -193,7 +193,7 @@ describe('processBedrockStream', () => {
parseInlineFunctionCalls({
logger: getLoggerMock(),
}),
rejectTokenCountEvents(),
withoutTokenCountEvents(),
concatenateChatCompletionChunks()
)
)

View file

@ -5,28 +5,24 @@
* 2.0.
*/
import { noop } from 'lodash';
import { forkJoin, last, Observable, shareReplay, tap } from 'rxjs';
import {
ChatCompletionChunkEvent,
createFunctionNotFoundError,
FunctionDefinition,
} from '../../../../common';
import { TokenCountEvent } from '../../../../common/conversation_complete';
import { ignoreElements, last, merge, Observable, shareReplay, tap } from 'rxjs';
import { createFunctionNotFoundError, FunctionDefinition } from '../../../../common';
import { ChatEvent } from '../../../../common/conversation_complete';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { rejectTokenCountEvents } from '../../util/reject_token_count_events';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';
export function failOnNonExistingFunctionCall({
functions,
}: {
functions?: Array<Pick<FunctionDefinition, 'name' | 'description' | 'parameters'>>;
}) {
return (source$: Observable<ChatCompletionChunkEvent | TokenCountEvent>) => {
return new Observable<ChatCompletionChunkEvent | TokenCountEvent>((subscriber) => {
const shared = source$.pipe(shareReplay());
return (source$: Observable<ChatEvent>) => {
const shared$ = source$.pipe(shareReplay());
const checkFunctionCallResponse$ = shared.pipe(
rejectTokenCountEvents(),
return merge(
shared$,
shared$.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
tap((event) => {
@ -36,24 +32,9 @@ export function failOnNonExistingFunctionCall({
) {
throw createFunctionNotFoundError(event.message.function_call.name);
}
})
);
source$.subscribe({
next: (val) => {
subscriber.next(val);
},
error: noop,
});
forkJoin([source$, checkFunctionCallResponse$]).subscribe({
complete: () => {
subscriber.complete();
},
error: (error) => {
subscriber.error(error);
},
});
});
}),
ignoreElements()
)
);
};
}

View file

@ -9,10 +9,7 @@ import type { Readable } from 'node:stream';
import type { Observable } from 'rxjs';
import type { Logger } from '@kbn/logging';
import type { Message } from '../../../../common';
import type {
ChatCompletionChunkEvent,
TokenCountEvent,
} from '../../../../common/conversation_complete';
import type { ChatEvent } from '../../../../common/conversation_complete';
import { CompatibleJSONSchema } from '../../../../common/functions/types';
export interface LlmFunction {
@ -31,7 +28,5 @@ export type LlmApiAdapterFactory = (options: {
export interface LlmApiAdapter {
getSubAction: () => { subAction: string; subActionParams: Record<string, any> };
streamIntoObservable: (
readable: Readable
) => Observable<ChatCompletionChunkEvent | TokenCountEvent>;
streamIntoObservable: (readable: Readable) => Observable<ChatEvent>;
}

View file

@ -0,0 +1,35 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { findLastIndex } from 'lodash';
import { Message, MessageAddEvent, MessageRole } from '../../../common';
import { createFunctionRequestMessage } from '../../../common/utils/create_function_request_message';
/**
 * Builds a `context` function request for the conversation, unless one is
 * already present after the most recent user-authored message.
 *
 * @param messages The conversation so far.
 * @returns A message-add event created by `createFunctionRequestMessage` for
 * the `context` function (with empty `queries` and `categories`), or
 * `undefined` when a message named `context` already follows the last
 * user-authored message.
 */
export function getContextFunctionRequestIfNeeded(
  messages: Message[]
): MessageAddEvent | undefined {
  // Only user-role messages without a `name` count here; named user-role
  // messages are excluded (presumably function responses - verify against callers).
  const indexOfLastUserMessage = findLastIndex(
    messages,
    (message) => message.message.role === MessageRole.User && !message.message.name
  );

  // `findLastIndex` returns -1 when nothing matches. A negative start makes
  // `slice` count from the end, so `slice(-1)` would look at the final message
  // only. Clamp to 0 so that, without a user-authored message, the whole
  // conversation is searched for an existing `context` message.
  const hasContextSinceLastUserMessage = messages
    .slice(Math.max(indexOfLastUserMessage, 0))
    .some((message) => message.message.name === 'context');

  if (hasContextSinceLastUserMessage) {
    return undefined;
  }

  return createFunctionRequestMessage({
    name: 'context',
    args: {
      queries: [],
      categories: [],
    },
  });
}

View file

@ -39,12 +39,14 @@ const nextTick = () => {
return new Promise(process.nextTick);
};
const waitForNextWrite = async (stream: Readable): Promise<void> => {
const waitForNextWrite = async (stream: Readable): Promise<any> => {
// this will fire before the client's internal write() promise is
// resolved
await new Promise((resolve) => stream.once('data', resolve));
const response = await new Promise((resolve) => stream.once('data', resolve));
// so we wait another tick to let the client move to the next step
await nextTick();
return response;
};
function createLlmSimulator() {
@ -108,12 +110,7 @@ describe('Observability AI Assistant client', () => {
getInstructions: jest.fn(),
} as any;
const loggerMock: DeeplyMockedKeys<Logger> = {
log: jest.fn(),
error: jest.fn(),
debug: jest.fn(),
trace: jest.fn(),
} as any;
let loggerMock: DeeplyMockedKeys<Logger> = {} as any;
const functionClientMock: DeeplyMockedKeys<ChatFunctionClient> = {
executeFunction: jest.fn(),
@ -130,6 +127,18 @@ describe('Observability AI Assistant client', () => {
function createClient() {
jest.resetAllMocks();
// uncomment this line for debugging
// const consoleOrPassThrough = console.log.bind(console);
const consoleOrPassThrough = () => {};
loggerMock = {
log: jest.fn().mockImplementation(consoleOrPassThrough),
error: jest.fn().mockImplementation(consoleOrPassThrough),
debug: jest.fn().mockImplementation(consoleOrPassThrough),
trace: jest.fn().mockImplementation(consoleOrPassThrough),
isLevelEnabled: jest.fn().mockReturnValue(true),
} as any;
functionClientMock.getFunctions.mockReturnValue([]);
functionClientMock.hasFunction.mockImplementation((name) => {
return name !== 'context';
@ -214,24 +223,27 @@ describe('Observability AI Assistant client', () => {
beforeEach(async () => {
client = createClient();
actionsClientMock.execute
.mockImplementationOnce(() => {
.mockImplementationOnce((body) => {
return new Promise((resolve, reject) => {
titleLlmPromiseResolve = (title: string) => {
const titleLlmSimulator = createLlmSimulator();
titleLlmSimulator.next({ content: title });
titleLlmSimulator.complete();
resolve({
actionId: '',
status: 'ok',
data: titleLlmSimulator.stream,
});
titleLlmSimulator
.next({ content: title })
.then(() => titleLlmSimulator.complete())
.then(() => {
resolve({
actionId: '',
status: 'ok',
data: titleLlmSimulator.stream,
});
});
};
titleLlmPromiseReject = () => {
reject();
titleLlmPromiseReject = (error: Error) => {
reject(error);
};
});
})
.mockImplementationOnce(async () => {
.mockImplementationOnce(async (body) => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
@ -260,6 +272,8 @@ describe('Observability AI Assistant client', () => {
stream.on('data', dataHandler);
await llmSimulator.next({ content: 'Hello' });
await nextTick();
});
it('calls the actions client with the messages', () => {
@ -346,9 +360,9 @@ describe('Observability AI Assistant client', () => {
id: expect.any(String),
last_updated: expect.any(String),
token_count: {
completion: 2,
prompt: 156,
total: 158,
completion: 1,
prompt: 78,
total: 79,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
@ -364,8 +378,6 @@ describe('Observability AI Assistant client', () => {
titleLlmPromiseResolve('An auto-generated title');
await nextTick();
await llmSimulator.complete();
await finished(stream);
@ -405,9 +417,9 @@ describe('Observability AI Assistant client', () => {
id: expect.any(String),
last_updated: expect.any(String),
token_count: {
completion: 8,
prompt: 340,
total: 348,
completion: 6,
prompt: 262,
total: 268,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
@ -423,9 +435,9 @@ describe('Observability AI Assistant client', () => {
last_updated: expect.any(String),
title: 'An auto-generated title',
token_count: {
completion: 8,
prompt: 340,
total: 348,
completion: 6,
prompt: 262,
total: 268,
},
},
labels: {},
@ -477,7 +489,7 @@ describe('Observability AI Assistant client', () => {
beforeEach(async () => {
client = createClient();
actionsClientMock.execute.mockImplementationOnce(async () => {
actionsClientMock.execute.mockImplementationOnce(async (body) => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
@ -499,6 +511,11 @@ describe('Observability AI Assistant client', () => {
id: 'my-conversation-id',
title: 'My stored conversation',
last_updated: new Date().toISOString(),
token_count: {
completion: 1,
prompt: 78,
total: 79,
},
},
labels: {},
numeric_labels: {},
@ -694,7 +711,7 @@ describe('Observability AI Assistant client', () => {
beforeEach(async () => {
client = createClient();
actionsClientMock.execute.mockImplementationOnce(async () => {
actionsClientMock.execute.mockImplementationOnce(async (body) => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
@ -794,7 +811,6 @@ describe('Observability AI Assistant client', () => {
it('executes the function', () => {
expect(functionClientMock.executeFunction).toHaveBeenCalledWith({
connectorId: 'foo',
name: 'myFunction',
chat: expect.any(Function),
args: JSON.stringify({ foo: 'bar' }),
@ -832,6 +848,7 @@ describe('Observability AI Assistant client', () => {
afterEach(async () => {
fnResponseResolve({ content: { my: 'content' } });
await waitForNextWrite(stream);
await llmSimulator.complete();
@ -993,7 +1010,12 @@ describe('Observability AI Assistant client', () => {
});
it('appends the function response', () => {
expect(JSON.parse(dataHandler.mock.lastCall!)).toEqual({
const parsed = JSON.parse(dataHandler.mock.lastCall!);
parsed.message.message.content = JSON.parse(parsed.message.message.content);
parsed.message.message.data = JSON.parse(parsed.message.message.data);
expect(parsed).toEqual({
type: StreamingChatResponseEventType.MessageAdd,
id: expect.any(String),
message: {
@ -1001,10 +1023,16 @@ describe('Observability AI Assistant client', () => {
message: {
role: MessageRole.User,
name: 'myFunction',
content: JSON.stringify({
message: 'Error: Function failed',
error: {},
}),
content: {
message: 'Function failed',
error: {
name: 'Error',
message: 'Function failed',
},
},
data: {
stack: expect.any(String),
},
},
},
});
@ -1138,7 +1166,7 @@ describe('Observability AI Assistant client', () => {
let dataHandler: jest.Mock;
beforeEach(async () => {
client = createClient();
actionsClientMock.execute.mockImplementationOnce(async () => {
actionsClientMock.execute.mockImplementationOnce(async (body) => {
llmSimulator = createLlmSimulator();
return {
actionId: '',
@ -1149,7 +1177,7 @@ describe('Observability AI Assistant client', () => {
functionClientMock.hasFunction.mockReturnValue(true);
functionClientMock.executeFunction.mockImplementationOnce(async () => {
functionClientMock.executeFunction.mockImplementationOnce(async (body) => {
return {
content: [
{
@ -1327,14 +1355,14 @@ describe('Observability AI Assistant client', () => {
await nextTick();
for (let i = 0; i <= maxFunctionCalls + 1; i++) {
for (let i = 0; i <= maxFunctionCalls; i++) {
await requestAlertsFunctionCall();
}
await finished(stream);
});
it('executed the function no more than three times', () => {
it(`executed the function no more than ${maxFunctionCalls} times`, () => {
expect(functionClientMock.executeFunction).toHaveBeenCalledTimes(maxFunctionCalls);
});

View file

@ -0,0 +1,294 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { decode, encode } from 'gpt-tokenizer';
import { pick, take } from 'lodash';
import {
catchError,
concat,
EMPTY,
from,
isObservable,
Observable,
of,
OperatorFunction,
shareReplay,
switchMap,
throwError,
} from 'rxjs';
import { createFunctionNotFoundError, Message, MessageRole } from '../../../../common';
import {
createFunctionLimitExceededError,
MessageOrChatEvent,
} from '../../../../common/conversation_complete';
import { FunctionVisibility } from '../../../../common/functions/types';
import { UserInstruction } from '../../../../common/types';
import { createFunctionResponseError } from '../../../../common/utils/create_function_response_error';
import { createFunctionResponseMessage } from '../../../../common/utils/create_function_response_message';
import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';
import type { ChatFunctionClient } from '../../chat_function_client';
import type { ChatFunctionWithoutConnector } from '../../types';
import { getSystemMessageFromInstructions } from '../../util/get_system_message_from_instructions';
import { replaceSystemMessage } from '../../util/replace_system_message';
import { extractMessages } from './extract_messages';
import { hideTokenCountEvents } from './hide_token_count_events';
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
/**
 * Executes a function via the function client and turns the outcome into a
 * stream of events.
 *
 * - If executing the function rejects, the error is converted into a function
 *   response error message instead of failing the stream.
 * - If the function returns an Observable, that Observable is emitted from
 *   directly; errors raised inside it are NOT caught here.
 * - If the function returns an object with a `type` property, it is emitted
 *   as-is.
 * - Otherwise the response is wrapped in a function response message, and its
 *   content is truncated when it reaches `MAX_FUNCTION_RESPONSE_TOKEN_COUNT`
 *   tokens.
 *
 * @param name Name of the function to execute.
 * @param args Serialized function arguments, if any.
 * @param functionClient Client used to execute the function.
 * @param messages Conversation messages handed to the function.
 * @param chat Chat function made available to the function being executed.
 * @param signal Abort signal forwarded to the function client.
 * @returns An Observable of the events produced by the function call.
 */
function executeFunctionAndCatchError({
  name,
  args,
  functionClient,
  messages,
  chat,
  signal,
}: {
  name: string;
  args: string | undefined;
  functionClient: ChatFunctionClient;
  messages: Message[];
  chat: ChatFunctionWithoutConnector;
  signal: AbortSignal;
}): Observable<MessageOrChatEvent> {
  // Functions should not have to deal with token count events, so these are
  // hidden from the `chat` stream that is handed to them.
  return hideTokenCountEvents((hide) =>
    from(
      functionClient.executeFunction({
        name,
        chat: (operationName, params) => chat(operationName, params).pipe(hide()),
        args,
        signal,
        messages,
      })
    ).pipe(
      catchError((error) => {
        // Only a rejection of the promise itself is recovered from. If the
        // failure happens inside an Observable returned by the function, values
        // may already have been emitted and the conversation state could be
        // invalid, so in that case the stream is allowed to fail.
        return of(createFunctionResponseError({ name, error }));
      }),
      switchMap((response) => {
        // the function streams its own events
        if (isObservable(response)) {
          return response;
        }

        // already a message-add event
        if ('type' in response) {
          return of(response);
        }

        const tokens = encode(JSON.stringify(response.content || {}));

        const content =
          tokens.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT
            ? {
                message: 'Function response exceeded the maximum length allowed and was truncated',
                truncated: decode(take(tokens, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
              }
            : response.content;

        return of(
          createFunctionResponseMessage({
            name,
            content,
            data: response.data,
          })
        );
      })
    )
  );
}
/**
 * Collects the function definitions that may be offered for the next request.
 *
 * @param functionClient Source of the registered functions and actions.
 * @param functionLimitExceeded When true, no definitions are returned at all.
 * @returns The `name`, `description` and `parameters` of every registered
 * function that has no visibility set or is visible to the assistant
 * (`AssistantOnly` or `All`), followed by those of every action.
 */
function getFunctionDefinitions({
  functionClient,
  functionLimitExceeded,
}: {
  functionClient: ChatFunctionClient;
  functionLimitExceeded: boolean;
}) {
  // nothing is exposed once the function call limit has been exceeded
  if (functionLimitExceeded) {
    return [];
  }

  const assistantVisibilities = [FunctionVisibility.AssistantOnly, FunctionVisibility.All];

  const systemFunctions = functionClient
    .getFunctions()
    .map((fn) => fn.definition)
    .filter(
      (definition) =>
        !definition.visibility || assistantVisibilities.includes(definition.visibility)
    );

  const actions = functionClient.getActions();

  return systemFunctions
    .concat(actions)
    .map((definition) => pick(definition, 'name', 'description', 'parameters'));
}
/**
 * Advances a conversation by one step and then recursively keeps advancing it
 * for as long as a step produces new messages.
 *
 * A step is decided by the last message of the conversation:
 * - a user message (or function response) is sent to the LLM via `chat`,
 * - an assistant message without a function call ends the stream,
 * - an assistant message with a function call is answered with a function
 *   response (the function result, or an error the LLM can react to), unless
 *   the call targets a valid client-side action, in which case the stream ends.
 *
 * @param messages the conversation so far; its system message is replaced
 * @param functionClient registry used for definitions, instructions, validation and lookups
 * @param chat function that calls the LLM
 * @param signal abort signal handed to the executed function
 * @param functionCallsLeft how many function calls may still be made
 * @param requestInstructions instructions passed along with the request
 * @param knowledgeBaseInstructions instructions coming from the knowledge base
 * @returns the events of this step followed by the events of all subsequent steps
 */
export function continueConversation({
messages: initialMessages,
functionClient,
chat,
signal,
functionCallsLeft,
requestInstructions,
knowledgeBaseInstructions,
}: {
messages: Message[];
functionClient: ChatFunctionClient;
chat: ChatFunctionWithoutConnector;
signal: AbortSignal;
functionCallsLeft: number;
requestInstructions: Array<string | UserInstruction>;
knowledgeBaseInstructions: UserInstruction[];
}): Observable<MessageOrChatEvent> {
// mutable copy of the budget: decremented in executeNextStep() and read
// later by handleEvents() when it starts the next step
let nextFunctionCallsLeft = functionCallsLeft;
// once the budget is used up, no functions are offered to the LLM anymore
const definitions = getFunctionDefinitions({
functionLimitExceeded: functionCallsLeft <= 0,
functionClient,
});
// rebuild the system message so it reflects the current instructions and
// the functions that are actually available in this step
const messagesWithUpdatedSystemMessage = replaceSystemMessage(
getSystemMessageFromInstructions({
registeredInstructions: functionClient.getInstructions(),
knowledgeBaseInstructions,
requestInstructions,
availableFunctionNames: definitions.map((def) => def.name),
}),
initialMessages
);
const lastMessage =
messagesWithUpdatedSystemMessage[messagesWithUpdatedSystemMessage.length - 1].message;
const isUserMessage = lastMessage.role === MessageRole.User;
// executeNextStep() runs right here (synchronously), so the budget has
// already been decremented by the time handleEvents() recurses
return executeNextStep().pipe(handleEvents());
// Decides what has to happen based on the last message and returns the
// events for that single step.
function executeNextStep() {
if (isUserMessage) {
// a user message that carries a name other than 'context' is treated as
// a function response for the purpose of naming the operation
const operationName =
lastMessage.name && lastMessage.name !== 'context'
? `function_response ${lastMessage.name}`
: 'user_message';
return chat(operationName, {
messages: messagesWithUpdatedSystemMessage,
functions: definitions,
}).pipe(emitWithConcatenatedMessage());
}
const functionCallName = lastMessage.function_call?.name;
if (!functionCallName) {
// reply from the LLM without a function request,
// so we can close the stream and wait for input from the user
return EMPTY;
}
// we know we are executing a function here, so we can already
// subtract one, and reference the old count for if clauses
const currentFunctionCallsLeft = nextFunctionCallsLeft;
nextFunctionCallsLeft--;
const isAction = functionCallName && functionClient.hasAction(functionCallName);
if (currentFunctionCallsLeft === 0) {
// create a function call response error so the LLM knows it needs to stop calling functions
return of(
createFunctionResponseError({
name: functionCallName,
error: createFunctionLimitExceededError(),
})
);
}
if (currentFunctionCallsLeft < 0) {
// LLM tried calling it anyway, throw an error
return throwError(() => createFunctionLimitExceededError());
}
// if it's an action, we close the stream and wait for the action response
// from the client/browser
if (isAction) {
try {
// missing arguments are validated as an empty object
functionClient.validate(
functionCallName,
JSON.parse(lastMessage.function_call!.arguments || '{}')
);
} catch (error) {
// return a function response error for the LLM to handle
return of(
createFunctionResponseError({
name: functionCallName,
error,
})
);
}
return EMPTY;
}
if (!functionClient.hasFunction(functionCallName)) {
// tell the LLM the function was not found
return of(
createFunctionResponseError({
name: functionCallName,
error: createFunctionNotFoundError(functionCallName),
})
);
}
// a registered function: execute it with the raw argument string
return executeFunctionAndCatchError({
name: functionCallName,
args: lastMessage.function_call!.arguments,
chat,
functionClient,
messages: messagesWithUpdatedSystemMessage,
signal,
});
}
// Passes the events of the current step through unchanged and, once that
// step has completed, starts the next step with the messages it produced.
function handleEvents(): OperatorFunction<MessageOrChatEvent, MessageOrChatEvent> {
return (events$) => {
// shared and replayed because the step's events are consumed twice below
const shared$ = events$.pipe(shareReplay());
return concat(
shared$,
shared$.pipe(
withoutTokenCountEvents(),
extractMessages(),
switchMap((extractedMessages) => {
// the step added no messages, so there is nothing to continue with
if (!extractedMessages.length) {
return EMPTY;
}
return continueConversation({
messages: messagesWithUpdatedSystemMessage.concat(extractedMessages),
chat,
functionCallsLeft: nextFunctionCallsLeft,
functionClient,
signal,
knowledgeBaseInstructions,
requestInstructions,
});
})
)
);
};
}
}

View file

@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { inspect } from 'util';
import { dematerialize, materialize, OperatorFunction, tap } from 'rxjs';
/**
 * Development helper that prints every notification of a stream (next, error
 * and complete) to the console, prefixed with `prefix`, and otherwise leaves
 * the stream untouched.
 *
 * @param prefix label printed before each inspected notification
 */
export function debug<T>(prefix: string): OperatorFunction<T, T> {
  return (source$) =>
    source$.pipe(
      // turn next/error/complete into plain notification objects so all of them can be printed
      materialize(),
      tap((notification) => {
        const dump = inspect(notification, { depth: 10 });
        // eslint-disable-next-line no-console
        console.log(`${prefix}:\n${dump}`);
      }),
      // and convert the notifications back into the original stream
      dematerialize()
    );
}

View file

@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, last, map, OperatorFunction, toArray } from 'rxjs';
import { Message, MessageAddEvent, StreamingChatResponseEventType } from '../../../../common';
import type { MessageOrChatEvent } from '../../../../common/conversation_complete';
/**
 * Collects the messages of all `MessageAdd` events of a stream and emits them
 * as a single array once the source completes. All other events are ignored.
 */
export function extractMessages(): OperatorFunction<MessageOrChatEvent, Message[]> {
  const isMessageAdd = (event: MessageOrChatEvent): event is MessageAddEvent =>
    event.type === StreamingChatResponseEventType.MessageAdd;

  return (source$) =>
    source$.pipe(
      filter(isMessageAdd),
      map(({ message }) => message),
      toArray(),
      last()
    );
}

View file

@ -0,0 +1,36 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction, scan } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../../../../common/conversation_complete';
/**
 * Keeps only the `TokenCount` events of a stream and emits the running totals
 * (completion, prompt and total tokens) after each of them.
 *
 * Every emission is a new object: the accumulator is never mutated in place.
 * Mutating it would change the seed itself, which is shared by every
 * subscription to the resulting Observable, so counts from one subscription
 * would leak into the next, and previously emitted totals would change after
 * the fact.
 */
export function extractTokenCount(): OperatorFunction<
  StreamingChatResponseEvent,
  TokenCountEvent['tokens']
> {
  return (events$) => {
    return events$.pipe(
      filter(
        (event): event is TokenCountEvent =>
          event.type === StreamingChatResponseEventType.TokenCount
      ),
      scan(
        (acc, event) => ({
          completion: acc.completion + event.tokens.completion,
          prompt: acc.prompt + event.tokens.prompt,
          total: acc.total + event.tokens.total,
        }),
        { completion: 0, prompt: 0, total: 0 }
      )
    );
  };
}

View file

@ -0,0 +1,105 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { catchError, map, Observable, of, tap } from 'rxjs';
import { Logger } from '@kbn/logging';
import type { ObservabilityAIAssistantClient } from '..';
import { Message, MessageRole } from '../../../../common';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { hideTokenCountEvents } from './hide_token_count_events';
import { ChatEvent, TokenCountEvent } from '../../../../common/conversation_complete';
/**
 * Signature of the chat function used for title generation: it takes an
 * operation name and the second argument of `ObservabilityAIAssistantClient.chat`
 * without `connectorId`, `signal` and `simulateFunctionCalling`, and returns a
 * stream of chat events.
 */
type ChatFunctionWithoutConnectorAndTokenCount = (
name: string,
params: Omit<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'connectorId' | 'signal' | 'simulateFunctionCalling'
>
) => Observable<ChatEvent>;
/**
 * Asks the LLM to generate a title for a conversation.
 *
 * The LLM is instructed to call the `title_conversation` function. The title
 * is taken from the function arguments when a function call was made, and
 * from the message content otherwise. A single pair of surrounding quotes is
 * stripped from the result.
 *
 * @param responseLanguage language the title should be written in; when it is
 *   not given, no language instruction is added to the prompt
 * @param messages the conversation; the first message is left out of the prompt
 * @param chat function that calls the LLM
 * @param logger used to log the generated title and any failure
 * @returns a stream with the title, plus the token count events of the LLM
 *   call. If anything fails, the error is logged and 'New conversation' is
 *   emitted instead of failing the stream.
 */
export function getGeneratedTitle({
  responseLanguage,
  messages,
  chat,
  logger,
}: {
  responseLanguage?: string;
  messages: Message[];
  chat: ChatFunctionWithoutConnectorAndTokenCount;
  logger: Logger;
}): Observable<string | TokenCountEvent> {
  // `responseLanguage` is optional: only mention it when it is set, so the
  // prompt never ends with "in undefined."
  const languageInstruction = responseLanguage
    ? ` Please create the title in ${responseLanguage}.`
    : '';

  return hideTokenCountEvents((hide) =>
    chat('generate_title', {
      messages: [
        {
          // ISO format, consistent with the user message below
          '@timestamp': new Date().toISOString(),
          message: {
            role: MessageRole.System,
            content: `You are a helpful assistant for Elastic Observability. Assume the following message is the start of a conversation between you and a user; give this conversation a title based on the content below. DO NOT UNDER ANY CIRCUMSTANCES wrap this title in single or double quotes. This title is shown in a list of conversations to the user, so title it for the user, not for you.${languageInstruction}`,
          },
        },
        {
          '@timestamp': new Date().toISOString(),
          message: {
            role: MessageRole.User,
            content: messages.slice(1).reduce((acc, curr) => {
              return `${acc} ${curr.message.role}: ${curr.message.content}`;
            }, 'Generate a title, using the title_conversation_function, based on the following conversation:\n\n'),
          },
        },
      ],
      functions: [
        {
          name: 'title_conversation',
          description:
            'Use this function to title the conversation. Do not wrap the title in quotes',
          parameters: {
            type: 'object',
            properties: {
              title: {
                type: 'string',
              },
            },
            required: ['title'],
          },
        },
      ],
      functionCall: 'title_conversation',
    }).pipe(
      // token count events are split off here and merged back into the result
      hide(),
      concatenateChatCompletionChunks(),
      map((concatenatedMessage) => {
        // a JSON.parse failure ends up in the catchError below
        const input =
          (concatenatedMessage.message.function_call.name
            ? JSON.parse(concatenatedMessage.message.function_call.arguments).title
            : concatenatedMessage.message?.content) || '';

        // This regular expression captures a string enclosed in single or double quotes.
        // It extracts the string content without the quotes.
        // Example matches:
        // - "Hello, World!" => Captures: Hello, World!
        // - 'Another Example' => Captures: Another Example
        // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
        const match = input.match(/^["']?([^"']+)["']?$/);
        const title = match ? match[1] : input;
        return title;
      }),
      tap((event) => {
        if (typeof event === 'string') {
          logger.debug(`Generated title: ${event}`);
        }
      })
    )
  ).pipe(
    catchError((error) => {
      logger.error(`Error generating title`);
      logger.error(error);
      // TODO: i18n
      return of('New conversation');
    })
  );
}

View file

@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { merge, Observable, partition, shareReplay } from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common';
import {
  StreamingChatResponseEventType,
  TokenCountEvent,
} from '../../../../common/conversation_complete';
/**
 * Operator factory handed to the callback of `hideTokenCountEvents`. The
 * operator it creates removes the token count events from a stream.
 */
type Hide = <T extends StreamingChatResponseEvent>() => (
  source$: Observable<T | TokenCountEvent>
) => Observable<Exclude<T, TokenCountEvent>>;

/**
 * Lets a pipeline temporarily ignore token count events without losing them.
 *
 * The callback receives a `hide` operator factory. Wherever `hide()` is
 * applied, the token count events are taken out of that stream, so the
 * operators that follow do not have to deal with them. They are merged back
 * into the Observable that this function returns.
 *
 * Note that the intercepted streams are collected when the callback returns,
 * so `hide()` has to be applied while the callback is building its pipeline.
 *
 * @param cb builds the pipeline, using `hide()` where token counts should be removed
 * @returns the events emitted by the pipeline plus all intercepted token count events
 */
export function hideTokenCountEvents<T>(
  cb: (hide: Hide) => Observable<Exclude<T, TokenCountEvent>>
): Observable<T | TokenCountEvent> {
  // `hide` can be called multiple times, so we keep track of each invocation
  const allInterceptors: Array<Observable<TokenCountEvent>> = [];

  const hide: Hide = () => (source$) => {
    // `partition` subscribes to its source once for each of the two streams
    // it returns. Share the source so that a cold Observable (e.g. a request
    // to the LLM) is only executed once.
    const shared$ = source$.pipe(shareReplay());

    const [tokenCountEvents$, otherEvents$] = partition(
      shared$,
      (value): value is TokenCountEvent => value.type === StreamingChatResponseEventType.TokenCount
    );

    allInterceptors.push(tokenCountEvents$);

    return otherEvents$;
  };

  // combine the two observables again
  return merge(cb(hide), ...allInterceptors);
}

View file

@ -0,0 +1,71 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import apm from 'elastic-apm-node';
import {
catchError,
ignoreElements,
merge,
OperatorFunction,
shareReplay,
tap,
last,
throwError,
finalize,
} from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common/conversation_complete';
import { extractTokenCount } from './extract_token_count';
/**
 * Wraps a stream in an APM span and labels that span with the token counts
 * that were reported on the stream.
 *
 * The events of the source are passed through unchanged. When no span could
 * be started, the source is returned as-is. Otherwise the span outcome is set
 * to 'success' when the source completes and to 'failure' when it errors, and
 * the span is labelled with the summed token counts and ended when the stream
 * terminates or is unsubscribed.
 *
 * @param name name of the APM span
 */
export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
  name: string
): OperatorFunction<T, T> {
  return (source$) => {
    const span = apm.startSpan(name);

    if (!span) {
      return source$;
    }

    span.addLabels({
      plugin: 'observability_ai_assistant',
    });

    // the source is consumed twice below: once for the pass-through and once
    // for the instrumentation
    const shared$ = source$.pipe(shareReplay());

    let tokenCount = {
      prompt: 0,
      completion: 0,
      total: 0,
    };

    return merge(
      shared$,
      shared$.pipe(
        extractTokenCount(),
        tap((nextTokenCount) => {
          tokenCount = nextTokenCount;
        }),
        // A stream without any token count event is perfectly valid. Without a
        // default value `last()` would throw an EmptyError in that case, which
        // would mark the span as failed and break the resulting stream.
        last(null, tokenCount),
        tap(() => {
          span.setOutcome('success');
        }),
        catchError((error) => {
          span.setOutcome('failure');
          return throwError(() => error);
        }),
        finalize(() => {
          span.addLabels({
            tokenCountPrompt: tokenCount.prompt,
            tokenCountCompletion: tokenCount.completion,
            tokenCountTotal: tokenCount.total,
          });
          span.end();
        }),
        // this branch only exists for its side effects
        ignoreElements()
      )
    );
  };
}

View file

@ -7,7 +7,7 @@
import type { FromSchema } from 'json-schema-to-ts';
import { Observable } from 'rxjs';
import { ChatCompletionChunkEvent } from '../../common/conversation_complete';
import { ChatCompletionChunkEvent, ChatEvent } from '../../common/conversation_complete';
import type {
CompatibleJSONSchema,
FunctionDefinition,
@ -27,17 +27,33 @@ export type RespondFunctionResources = Pick<
'context' | 'logger' | 'plugins' | 'request'
>;
export type ChatFn = (
...args: Parameters<ObservabilityAIAssistantClient['chat']>
) => Promise<Observable<ChatCompletionChunkEvent>>;
/**
 * Chat function that takes an operation name and the complete second argument
 * of `ObservabilityAIAssistantClient.chat`, and returns a stream of chat events.
 */
export type ChatFunction = (
name: string,
params: Parameters<ObservabilityAIAssistantClient['chat']>[1]
) => Observable<ChatEvent>;
/**
 * Same as `ChatFunction`, but the caller does not pass `connectorId`,
 * `simulateFunctionCalling` or `signal`.
 */
export type ChatFunctionWithoutConnector = (
name: string,
params: Omit<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'connectorId' | 'simulateFunctionCalling'  | 'signal'
>
) => Observable<ChatEvent>;
/**
 * Chat function handed to function implementations (see `RespondFunction`):
 * the caller does not pass `connectorId` or `simulateFunctionCalling` (but
 * still passes `signal`), and the returned stream only contains chat
 * completion chunks.
 */
export type FunctionCallChatFunction = (
name: string,
params: Omit<
Parameters<ObservabilityAIAssistantClient['chat']>[1],
'connectorId' | 'simulateFunctionCalling'
>
) => Observable<ChatCompletionChunkEvent>;
type RespondFunction<TArguments, TResponse extends FunctionResponse> = (
options: {
arguments: TArguments;
messages: Message[];
connectorId: string;
screenContexts: ObservabilityAIAssistantScreenContextRequest[];
chat: ChatFn;
chat: FunctionCallChatFunction;
},
signal: AbortSignal
) => Promise<TResponse>;

View file

@ -9,16 +9,15 @@ import { i18n } from '@kbn/i18n';
import { catchError, filter, of, OperatorFunction, shareReplay, throwError } from 'rxjs';
import {
ChatCompletionChunkEvent,
MessageAddEvent,
MessageRole,
StreamingChatResponseEventType,
} from '../../../common';
import { isFunctionNotFoundError } from '../../../common/conversation_complete';
import { isFunctionNotFoundError, MessageOrChatEvent } from '../../../common/conversation_complete';
import { emitWithConcatenatedMessage } from '../../../common/utils/emit_with_concatenated_message';
export function catchFunctionLimitExceededError(): OperatorFunction<
ChatCompletionChunkEvent | MessageAddEvent,
ChatCompletionChunkEvent | MessageAddEvent
MessageOrChatEvent,
MessageOrChatEvent
> {
return (source$) => {
const shared$ = source$.pipe(shareReplay());

View file

@ -13,11 +13,10 @@ import {
isChatCompletionError,
StreamingChatResponseEventType,
StreamingChatResponseEventWithoutError,
TokenCountEvent,
} from '../../../common/conversation_complete';
export function observableIntoStream(
source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent | TokenCountEvent>
source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent>
) {
const stream = new PassThrough();

View file

@ -1,26 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, Observable } from 'rxjs';
import {
ChatCompletionChunkEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../../../common/conversation_complete';
/**
 * Creates an operator that drops all `TokenCount` events from a stream and
 * narrows the event type of the stream accordingly.
 */
export function rejectTokenCountEvents() {
  return <T extends ChatCompletionChunkEvent | TokenCountEvent>(
    source: Observable<T>
  ): Observable<Exclude<T, TokenCountEvent>> => {
    const isNotTokenCount = (event: T): event is Exclude<T, TokenCountEvent> =>
      event.type !== StreamingChatResponseEventType.TokenCount;

    return source.pipe(filter(isNotTokenCount));
  };
}

View file

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { withSpan, SpanOptions, parseSpanOptions } from '@kbn/apm-utils';
/**
 * Runs `cb` inside an APM span that is marked as belonging to the
 * Observability AI Assistant plugin.
 *
 * The span always gets a `plugin` label, which the caller's labels may
 * override. Unless `intercept` is set, the span type defaults to
 * 'plugin:observability_ai_assistant'; a type given by the caller wins.
 *
 * @param optionsOrName span options, or just the name of the span
 * @param cb the work to run inside the span
 * @returns whatever `withSpan` returns for `cb`
 */
export function withAssistantSpan<T>(
  optionsOrName: SpanOptions | string,
  cb: () => Promise<T>
): Promise<T> {
  const parsed = parseSpanOptions(optionsOrName);

  // spread before `parsed`, so an explicit type given by the caller wins
  const defaultType = parsed.intercept ? {} : { type: 'plugin:observability_ai_assistant' };

  const labels = {
    plugin: 'observability_ai_assistant',
    ...parsed.labels,
  };

  return withSpan({ ...defaultType, ...parsed, labels }, cb);
}

View file

@ -48,6 +48,7 @@
"@kbn/cloud-plugin",
"@kbn/serverless",
"@kbn/triggers-actions-ui-plugin",
"@kbn/apm-utils"
],
"exclude": ["target/**/*"]
}

View file

@ -142,7 +142,7 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
description: `This function generates, executes and/or visualizes a query based on the user's request. It also explains how ES|QL works and how to convert queries from one language to another. Make sure you call one of the get_dataset functions first if you need index or field names. This function takes no input.`,
visibility: FunctionVisibility.AssistantOnly,
},
async ({ messages, connectorId, chat }, signal) => {
async ({ messages, chat }, signal) => {
const [systemMessage, esqlDocs] = await Promise.all([loadSystemMessage(), loadEsqlDocs()]);
const withEsqlSystemMessage = (message?: string) => [
@ -155,7 +155,6 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
const source$ = (
await chat('classify_esql', {
connectorId,
messages: withEsqlSystemMessage().concat({
'@timestamp': new Date().toISOString(),
message: {
@ -382,7 +381,6 @@ export function registerQueryFunction({ functions, resources }: FunctionRegistra
},
},
],
connectorId,
signal,
functions: functions.getActions(),
});

View file

@ -27,25 +27,22 @@ export function registerVisualizeESQLFunction({
functions,
resources,
}: FunctionRegistrationParameters) {
functions.registerFunction(
visualizeESQLFunction,
async ({ arguments: { query, intention }, connectorId, messages }, signal) => {
const { columns, errorMessages } = await validateEsqlQuery({
query,
client: (await resources.context.core).elasticsearch.client.asCurrentUser,
});
functions.registerFunction(visualizeESQLFunction, async ({ arguments: { query, intention } }) => {
const { columns, errorMessages } = await validateEsqlQuery({
query,
client: (await resources.context.core).elasticsearch.client.asCurrentUser,
});
const message = getMessageForLLM(intention, query, Boolean(errorMessages?.length));
const message = getMessageForLLM(intention, query, Boolean(errorMessages?.length));
return {
data: {
columns,
},
content: {
message,
errorMessages,
},
};
}
);
return {
data: {
columns,
},
content: {
message,
errorMessages,
},
};
});
}

View file

@ -160,6 +160,7 @@ export default function ApiTest({ getService }: FtrProviderContext) {
connectorId,
functions: [],
})
.expect(200)
.pipe(passThrough);
let data: string = '';
@ -188,9 +189,9 @@ export default function ApiTest({ getService }: FtrProviderContext) {
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
const response = JSON.parse(data);
const response = JSON.parse(data.trim());
expect(response.message).to.be(
expect(response.error.message).to.be(
`Token limit reached. Token limit is 8192, but the current conversation has 11036 tokens.`
);
});