[Obs AI Assistant] Remove TokenCountEvent (#209549)

Closes https://github.com/elastic/kibana/issues/205479

This filters out the `ChatCompletionTokenCountEvent` emitted by the inference
plugin, which greatly simplifies handling chat completion events in the
Obs AI Assistant.
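
For illustration, here is a minimal, self-contained sketch of the filtering approach this commit applies in `convertInferenceEventsToStreamingEvents`: token count events are dropped from the stream up front, so downstream operators only ever see chunk and message events. The event shapes and the `withoutTokenCounts` helper are simplified stand-ins for this sketch, not the actual `@kbn/inference-common` or Obs AI Assistant types.

```ts
import { filter, of, OperatorFunction } from 'rxjs';

// Simplified stand-ins for the inference plugin's chat completion events.
type InferenceEvent =
  | { type: 'chatCompletionChunk'; content: string }
  | {
      type: 'chatCompletionTokenCount';
      tokens: { completion: number; prompt: number; total: number };
    }
  | { type: 'chatCompletionMessage'; content: string };

// After filtering, only chunk and message events remain.
type StreamingEvent = Exclude<InferenceEvent, { type: 'chatCompletionTokenCount' }>;

// Drop token count events before any downstream operator sees them.
function withoutTokenCounts(): OperatorFunction<InferenceEvent, StreamingEvent> {
  return (events$) =>
    events$.pipe(
      filter(
        (event): event is StreamingEvent => event.type !== 'chatCompletionTokenCount'
      )
    );
}

// Usage: the token count event never reaches the subscriber.
of<InferenceEvent>(
  { type: 'chatCompletionChunk', content: 'Hello' },
  { type: 'chatCompletionTokenCount', tokens: { completion: 1, prompt: 2, total: 3 } },
  { type: 'chatCompletionMessage', content: 'Hello' }
)
  .pipe(withoutTokenCounts())
  .subscribe((event) => console.log(event.type)); // logs chunk and message types only
```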
Søren Louv-Jansen 2025-02-20 18:59:19 +01:00 committed by GitHub
parent 283cb29606
commit c4826bdfbf
26 changed files with 96 additions and 541 deletions

View file

@ -6,7 +6,7 @@
*/
import { i18n } from '@kbn/i18n';
import { TokenCount as TokenCountType, type Message } from './types';
import { type Message } from './types';
export enum StreamingChatResponseEventType {
ChatCompletionChunk = 'chatCompletionChunk',
@ -16,7 +16,6 @@ export enum StreamingChatResponseEventType {
MessageAdd = 'messageAdd',
ChatCompletionError = 'chatCompletionError',
BufferFlush = 'bufferFlush',
TokenCount = 'tokenCount',
}
type StreamingChatResponseEventBase<
@ -54,7 +53,6 @@ export type ConversationCreateEvent = StreamingChatResponseEventBase<
id: string;
title: string;
last_updated: string;
token_count?: TokenCountType;
};
}
>;
@ -66,7 +64,6 @@ export type ConversationUpdateEvent = StreamingChatResponseEventBase<
id: string;
title: string;
last_updated: string;
token_count?: TokenCountType;
};
}
>;
@ -95,17 +92,6 @@ export type BufferFlushEvent = StreamingChatResponseEventBase<
}
>;
export type TokenCountEvent = StreamingChatResponseEventBase<
StreamingChatResponseEventType.TokenCount,
{
tokens: {
completion: number;
prompt: number;
total: number;
};
}
>;
export type StreamingChatResponseEvent =
| ChatCompletionChunkEvent
| ChatCompletionMessageEvent
@ -113,7 +99,6 @@ export type StreamingChatResponseEvent =
| ConversationUpdateEvent
| MessageAddEvent
| ChatCompletionErrorEvent
| TokenCountEvent
| BufferFlushEvent;
export type StreamingChatResponseEventWithoutError = Exclude<
@ -121,7 +106,7 @@ export type StreamingChatResponseEventWithoutError = Exclude<
ChatCompletionErrorEvent
>;
export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent;
export type ChatEvent = ChatCompletionChunkEvent | ChatCompletionMessageEvent;
export type MessageOrChatEvent = ChatEvent | MessageAddEvent;
export enum ChatCompletionErrorCode {

View file

@ -18,7 +18,6 @@ export {
export type {
ChatCompletionChunkEvent,
ChatCompletionMessageEvent,
TokenCountEvent,
ConversationCreateEvent,
ConversationUpdateEvent,
MessageAddEvent,

View file

@ -45,12 +45,6 @@ export interface Message {
};
}
export interface TokenCount {
prompt: number;
completion: number;
total: number;
}
export interface Conversation {
'@timestamp': string;
user?: {
@ -61,7 +55,6 @@ export interface Conversation {
id: string;
title: string;
last_updated: string;
token_count?: TokenCount;
};
systemMessage?: string;
messages: Message[];
@ -71,8 +64,8 @@ export interface Conversation {
public: boolean;
}
export type ConversationRequestBase = Omit<Conversation, 'user' | 'conversation' | 'namespace'> & {
conversation: { title: string; token_count?: TokenCount; id?: string };
type ConversationRequestBase = Omit<Conversation, 'user' | 'conversation' | 'namespace'> & {
conversation: { title: string; id?: string };
};
export type ConversationCreateRequest = ConversationRequestBase;

View file

@ -16,7 +16,6 @@ import {
withLatestFrom,
filter,
} from 'rxjs';
import { withoutTokenCountEvents } from './without_token_count_events';
import {
type ChatCompletionChunkEvent,
ChatEvent,
@ -69,15 +68,12 @@ export function emitWithConcatenatedMessage<T extends ChatEvent>(
return (source$) => {
const shared = source$.pipe(shareReplay());
const withoutTokenCount$ = shared.pipe(filterChunkEvents());
const response$ = concat(
shared,
shared.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
withLatestFrom(withoutTokenCount$),
withLatestFrom(shared.pipe(filterChunkEvents())),
mergeMap(([message, chunkEvent]) => {
return mergeWithEditedMessage(message, chunkEvent, callback);
})

View file

@ -1,23 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../conversation_complete';
export function withoutTokenCountEvents<T extends StreamingChatResponseEvent>(): OperatorFunction<
T,
Exclude<T, TokenCountEvent>
> {
return filter(
(event): event is Exclude<T, TokenCountEvent> =>
event.type !== StreamingChatResponseEventType.TokenCount
);
}

View file

@ -65,28 +65,6 @@ export const chatFeedbackEventSchema: EventTypeOpts<ChatFeedback> = {
description: 'The timestamp of the last message in the conversation.',
},
},
token_count: {
properties: {
completion: {
type: 'long',
_meta: {
description: 'The number of tokens in the completion.',
},
},
prompt: {
type: 'long',
_meta: {
description: 'The number of tokens in the prompt.',
},
},
total: {
type: 'long',
_meta: {
description: 'The total number of tokens in the conversation.',
},
},
},
},
},
},
labels: {

View file

@ -13,7 +13,6 @@ import { Readable } from 'stream';
import { AssistantScope } from '@kbn/ai-assistant-common';
import { aiAssistantSimulatedFunctionCalling } from '../..';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
import { flushBuffer } from '../../service/util/flush_buffer';
import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';
@ -203,16 +202,14 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
recallAndScore({
analytics: (await resources.plugins.core.start()).analytics,
chat: (name, params) =>
client
.chat(name, {
...params,
stream: true,
connectorId,
simulateFunctionCalling,
signal,
tracer: new LangTracer(otelContext.active()),
})
.pipe(withoutTokenCountEvents()),
client.chat(name, {
...params,
stream: true,
connectorId,
simulateFunctionCalling,
signal,
tracer: new LangTracer(otelContext.active()),
}),
context,
logger: resources.logger,
messages: [],

View file

@ -7,9 +7,7 @@
import * as t from 'io-ts';
import { toBooleanRt } from '@kbn/io-ts-utils';
import {
type Conversation,
type ConversationCreateRequest,
type ConversationRequestBase,
type ConversationUpdateRequest,
type Message,
MessageRole,
@ -57,17 +55,12 @@ const tokenCountRt = t.type({
total: t.number,
});
export const baseConversationRt: t.Type<ConversationRequestBase> = t.intersection([
export const conversationCreateRt: t.Type<ConversationCreateRequest> = t.intersection([
t.type({
'@timestamp': t.string,
conversation: t.intersection([
t.type({
title: t.string,
}),
t.partial({
token_count: tokenCountRt,
}),
]),
conversation: t.type({
title: t.string,
}),
messages: t.array(messageRt),
labels: t.record(t.string, t.string),
numeric_labels: t.record(t.string, t.number),
@ -84,17 +77,8 @@ export const assistantScopeType = t.union([
t.literal('all'),
]);
export const conversationCreateRt: t.Type<ConversationCreateRequest> = t.intersection([
baseConversationRt,
t.type({
conversation: t.type({
title: t.string,
}),
}),
]);
export const conversationUpdateRt: t.Type<ConversationUpdateRequest> = t.intersection([
baseConversationRt,
conversationCreateRt,
t.type({
conversation: t.intersection([
t.type({
@ -102,33 +86,12 @@ export const conversationUpdateRt: t.Type<ConversationUpdateRequest> = t.interse
title: t.string,
}),
t.partial({
token_count: tokenCountRt,
token_count: tokenCountRt, // deprecated, but kept for backwards compatibility
}),
]),
}),
]);
export const conversationRt: t.Type<Conversation> = t.intersection([
baseConversationRt,
t.intersection([
t.type({
namespace: t.string,
conversation: t.intersection([
t.type({
id: t.string,
last_updated: t.string,
}),
t.partial({
token_count: tokenCountRt,
}),
]),
}),
t.partial({
user: t.intersection([t.type({ name: t.string }), t.partial({ id: t.string })]),
}),
]),
]);
export const functionRt = t.intersection([
t.type({
name: t.string,

View file

@ -69,21 +69,6 @@ function createLlmSimulator(subscriber: any) {
toolCalls: msg.function_call ? [{ function: msg.function_call }] : [],
});
},
tokenCount: async ({
completion,
prompt,
total,
}: {
completion: number;
prompt: number;
total: number;
}) => {
subscriber.next({
type: InferenceChatCompletionEventType.ChatCompletionTokenCount,
tokens: { completion, prompt, total },
});
subscriber.complete();
},
chunk: async (msg: ChunkDelta) => {
subscriber.next({
type: InferenceChatCompletionEventType.ChatCompletionChunk,
@ -252,11 +237,6 @@ describe('Observability AI Assistant client', () => {
},
},
],
tokens: {
completion: 0,
prompt: 0,
total: 0,
},
})
.catch((error) => titleLlmSimulator.error(error));
};
@ -388,7 +368,6 @@ describe('Observability AI Assistant client', () => {
titleLlmPromiseReject(new Error('Failed generating title'));
await nextTick();
await llmSimulator.tokenCount({ completion: 1, prompt: 33, total: 34 });
await llmSimulator.complete();
await finished(stream);
@ -400,11 +379,6 @@ describe('Observability AI Assistant client', () => {
title: 'New conversation',
id: expect.any(String),
last_updated: expect.any(String),
token_count: {
completion: 1,
prompt: 33,
total: 34,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
});
@ -418,7 +392,6 @@ describe('Observability AI Assistant client', () => {
await llmSimulator.chunk({ content: ' again' });
titleLlmPromiseResolve('An auto-generated title');
await llmSimulator.tokenCount({ completion: 6, prompt: 210, total: 216 });
await llmSimulator.complete();
await finished(stream);
@ -457,11 +430,6 @@ describe('Observability AI Assistant client', () => {
title: 'An auto-generated title',
id: expect.any(String),
last_updated: expect.any(String),
token_count: {
completion: 6,
prompt: 210,
total: 216,
},
},
type: StreamingChatResponseEventType.ConversationCreate,
});
@ -475,11 +443,6 @@ describe('Observability AI Assistant client', () => {
id: expect.any(String),
last_updated: expect.any(String),
title: 'An auto-generated title',
token_count: {
completion: 6,
prompt: 210,
total: 216,
},
},
labels: {},
numeric_labels: {},
@ -543,11 +506,6 @@ describe('Observability AI Assistant client', () => {
id: 'my-conversation-id',
title: 'My stored conversation',
last_updated: new Date().toISOString(),
token_count: {
completion: 1,
prompt: 78,
total: 79,
},
},
labels: {},
numeric_labels: {},
@ -583,7 +541,6 @@ describe('Observability AI Assistant client', () => {
await llmSimulator.chunk({ content: 'Hello' });
await llmSimulator.next({ content: 'Hello' });
await llmSimulator.tokenCount({ completion: 1, prompt: 33, total: 34 });
await llmSimulator.complete();
await finished(stream);
@ -595,11 +552,6 @@ describe('Observability AI Assistant client', () => {
title: 'My stored conversation',
id: expect.any(String),
last_updated: expect.any(String),
token_count: {
completion: 2,
prompt: 111,
total: 113,
},
},
type: StreamingChatResponseEventType.ConversationUpdate,
});
@ -614,11 +566,6 @@ describe('Observability AI Assistant client', () => {
id: expect.any(String),
last_updated: expect.any(String),
title: 'My stored conversation',
token_count: {
completion: 2,
prompt: 111,
total: 113,
},
},
labels: {},
numeric_labels: {},
@ -900,7 +847,6 @@ describe('Observability AI Assistant client', () => {
beforeEach(async () => {
await llmSimulator.chunk({ content: 'I am done here' });
await llmSimulator.next({ content: 'I am done here' });
await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await llmSimulator.complete();
await waitForNextWrite(stream);
@ -940,11 +886,6 @@ describe('Observability AI Assistant client', () => {
id: expect.any(String),
last_updated: expect.any(String),
title: 'My predefined title',
token_count: {
completion: expect.any(Number),
prompt: expect.any(Number),
total: expect.any(Number),
},
},
});
@ -1200,7 +1141,6 @@ describe('Observability AI Assistant client', () => {
await llmSimulator.chunk({ content: 'Hello' });
await llmSimulator.next({ content: 'Hello' });
await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await llmSimulator.complete();
await finished(stream);
@ -1492,7 +1432,6 @@ describe('Observability AI Assistant client', () => {
await llmSimulator.chunk({ function_call: { name: 'get_top_alerts' } });
await llmSimulator.next({ content: 'done' });
await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await llmSimulator.complete();
await waitFor(() => functionResponsePromiseResolve !== undefined);
@ -1612,7 +1551,6 @@ describe('Observability AI Assistant client', () => {
function_call: { name: 'my_action', arguments: JSON.stringify({ foo: 'bar' }) },
});
await llmSimulator.next({ content: 'content' });
await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await llmSimulator.complete();
});
@ -1643,7 +1581,6 @@ describe('Observability AI Assistant client', () => {
await llmSimulator.next({
content: 'Looks like the function call failed',
});
await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await llmSimulator.complete();
});

View file

@ -41,7 +41,6 @@ import {
ConversationUpdateEvent,
createConversationNotFoundError,
StreamingChatResponseEventType,
TokenCountEvent,
type StreamingChatResponseEvent,
} from '../../../common/conversation_complete';
import { convertMessagesForInference } from '../../../common/convert_messages_for_inference';
@ -56,7 +55,6 @@ import {
KnowledgeBaseType,
KnowledgeBaseEntryRole,
} from '../../../common/types';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import type { ChatFunctionClient } from '../chat_function_client';
import { KnowledgeBaseService, RecalledEntry } from '../knowledge_base_service';
@ -68,9 +66,7 @@ import { LangTracer } from './instrumentation/lang_tracer';
import { continueConversation } from './operators/continue_conversation';
import { convertInferenceEventsToStreamingEvents } from './operators/convert_inference_events_to_streaming_events';
import { extractMessages } from './operators/extract_messages';
import { extractTokenCount } from './operators/extract_token_count';
import { getGeneratedTitle } from './operators/get_generated_title';
import { instrumentAndCountTokens } from './operators/instrument_and_count_tokens';
import {
reIndexKnowledgeBaseAndPopulateSemanticTextField,
scheduleKbSemanticTextMigrationTask,
@ -78,6 +74,7 @@ import {
import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
import { ObservabilityAIAssistantConfig } from '../../config';
import { getElserModelId } from '../knowledge_base_service/get_elser_model_id';
import { apmInstrumentation } from './operators/apm_instrumentation';
const MAX_FUNCTION_CALLS = 8;
@ -297,17 +294,12 @@ export class ObservabilityAIAssistantClient {
// wait until all dependencies have completed
forkJoin([
// get just the new messages
nextEvents$.pipe(withoutTokenCountEvents(), extractMessages()),
// count all the token count events emitted during completion
mergeOperator(
nextEvents$,
title$.pipe(filter((value): value is TokenCountEvent => typeof value !== 'string'))
).pipe(extractTokenCount()),
nextEvents$.pipe(extractMessages()),
// get just the title, and drop the token count events
title$.pipe(filter((value): value is string => typeof value === 'string')),
systemMessage$,
]).pipe(
switchMap(([addedMessages, tokenCountResult, title, systemMessage]) => {
switchMap(([addedMessages, title, systemMessage]) => {
const initialMessagesWithAddedMessages = initialMessages.concat(addedMessages);
const lastMessage = last(initialMessagesWithAddedMessages);
@ -329,13 +321,6 @@ export class ObservabilityAIAssistantClient {
return throwError(() => createConversationNotFoundError());
}
const persistedTokenCount = conversation._source?.conversation
.token_count ?? {
prompt: 0,
completion: 0,
total: 0,
};
return from(
this.update(
conversationId,
@ -349,16 +334,10 @@ export class ObservabilityAIAssistantClient {
// update messages and system message
{ messages: initialMessagesWithAddedMessages, systemMessage },
// update token count
// update title
{
conversation: {
title: title || conversation._source?.conversation.title,
token_count: {
prompt: persistedTokenCount.prompt + tokenCountResult.prompt,
completion:
persistedTokenCount.completion + tokenCountResult.completion,
total: persistedTokenCount.total + tokenCountResult.total,
},
},
}
)
@ -382,7 +361,6 @@ export class ObservabilityAIAssistantClient {
conversation: {
title,
id: conversationId,
token_count: tokenCountResult,
},
public: !!isPublic,
labels: {},
@ -403,8 +381,7 @@ export class ObservabilityAIAssistantClient {
);
return output$.pipe(
instrumentAndCountTokens('complete'),
withoutTokenCountEvents(),
apmInstrumentation('complete'),
catchError((error) => {
this.dependencies.logger.error(error);
return throwError(() => error);
@ -462,7 +439,7 @@ export class ObservabilityAIAssistantClient {
stream: TStream;
}
): TStream extends true
? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
? Observable<ChatCompletionChunkEvent | ChatCompletionMessageEvent>
: Promise<ChatCompleteResponse> {
let tools: Record<string, { description: string; schema: any }> | undefined;
let toolChoice: ToolChoiceType | { function: string } | undefined;
@ -500,7 +477,7 @@ export class ObservabilityAIAssistantClient {
})
).pipe(
convertInferenceEventsToStreamingEvents(),
instrumentAndCountTokens(name),
apmInstrumentation(name),
failOnNonExistingFunctionCall({ functions }),
tap((event) => {
if (
@ -512,7 +489,7 @@ export class ObservabilityAIAssistantClient {
}),
shareReplay()
) as TStream extends true
? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
? Observable<ChatCompletionChunkEvent | ChatCompletionMessageEvent>
: never;
} else {
return this.dependencies.inferenceClient.chatComplete({

View file

@ -18,9 +18,8 @@ import {
finalize,
} from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common/conversation_complete';
import { extractTokenCount } from './extract_token_count';
export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
export function apmInstrumentation<T extends StreamingChatResponseEvent>(
name: string
): OperatorFunction<T, T> {
return (source$) => {
@ -35,19 +34,9 @@ export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
const shared$ = source$.pipe(shareReplay());
let tokenCount = {
prompt: 0,
completion: 0,
total: 0,
};
return merge(
shared$,
shared$.pipe(
extractTokenCount(),
tap((nextTokenCount) => {
tokenCount = nextTokenCount;
}),
last(),
tap(() => {
span?.setOutcome('success');
@ -57,11 +46,6 @@ export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
return throwError(() => error);
}),
finalize(() => {
span?.addLabels({
tokenCountPrompt: tokenCount.prompt,
tokenCountCompletion: tokenCount.completion,
tokenCountTotal: tokenCount.total,
});
span?.end();
}),
ignoreElements()

View file

@ -31,14 +31,12 @@ import { FunctionVisibility } from '../../../../common/functions/types';
import { AdHocInstruction, Instruction } from '../../../../common/types';
import { createFunctionResponseMessage } from '../../../../common/utils/create_function_response_message';
import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';
import type { ChatFunctionClient } from '../../chat_function_client';
import type { AutoAbortedChatFunction } from '../../types';
import { createServerSideFunctionResponseError } from '../../util/create_server_side_function_response_error';
import { LangTracer } from '../instrumentation/lang_tracer';
import { catchFunctionNotFoundError } from './catch_function_not_found_error';
import { extractMessages } from './extract_messages';
import { hideTokenCountEvents } from './hide_token_count_events';
const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;
@ -68,67 +66,65 @@ function executeFunctionAndCatchError({
// hide token count events from functions to prevent them from
// having to deal with it as well
return tracer.startActiveSpan(`execute_function ${name}`, ({ tracer: nextTracer }) =>
hideTokenCountEvents((hide) => {
const executeFunctionResponse$ = from(
functionClient.executeFunction({
name,
chat: (operationName, params) => {
return chat(operationName, {
...params,
tracer: nextTracer,
connectorId,
}).pipe(hide());
},
args,
signal,
messages,
connectorId,
simulateFunctionCalling,
})
);
return tracer.startActiveSpan(`execute_function ${name}`, ({ tracer: nextTracer }) => {
const executeFunctionResponse$ = from(
functionClient.executeFunction({
name,
chat: (operationName, params) => {
return chat(operationName, {
...params,
tracer: nextTracer,
connectorId,
});
},
args,
signal,
messages,
connectorId,
simulateFunctionCalling,
})
);
return executeFunctionResponse$.pipe(
catchError((error) => {
logger.error(`Encountered error running function ${name}: ${JSON.stringify(error)}`);
// We want to catch the error only when a promise occurs
// if it occurs in the Observable, we cannot easily recover
// from it because the function may have already emitted
// values which could lead to an invalid conversation state,
// so in that case we let the stream fail.
return of(createServerSideFunctionResponseError({ name, error }));
}),
switchMap((response) => {
if (isObservable(response)) {
return response;
}
return executeFunctionResponse$.pipe(
catchError((error) => {
logger.error(`Encountered error running function ${name}: ${JSON.stringify(error)}`);
// We want to catch the error only when a promise occurs
// if it occurs in the Observable, we cannot easily recover
// from it because the function may have already emitted
// values which could lead to an invalid conversation state,
// so in that case we let the stream fail.
return of(createServerSideFunctionResponseError({ name, error }));
}),
switchMap((response) => {
if (isObservable(response)) {
return response;
}
// is messageAdd event
if ('type' in response) {
return of(response);
}
// is messageAdd event
if ('type' in response) {
return of(response);
}
const encoded = encode(JSON.stringify(response.content || {}));
const encoded = encode(JSON.stringify(response.content || {}));
const exceededTokenLimit = encoded.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT;
const exceededTokenLimit = encoded.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT;
return of(
createFunctionResponseMessage({
name,
content: exceededTokenLimit
? {
message:
'Function response exceeded the maximum length allowed and was truncated',
truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
}
: response.content,
data: response.data,
})
);
})
);
})
);
return of(
createFunctionResponseMessage({
name,
content: exceededTokenLimit
? {
message:
'Function response exceeded the maximum length allowed and was truncated',
truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
}
: response.content,
data: response.data,
})
);
})
);
});
}
function getFunctionDefinitions({
@ -315,7 +311,6 @@ export function continueConversation({
return concat(
shared$,
shared$.pipe(
withoutTokenCountEvents(),
extractMessages(),
switchMap((extractedMessages) => {
if (!extractedMessages.length) {

View file

@ -5,7 +5,7 @@
* 2.0.
*/
import { Observable, OperatorFunction, map } from 'rxjs';
import { Observable, OperatorFunction, filter, map } from 'rxjs';
import { v4 } from 'uuid';
import {
ChatCompletionEvent as InferenceChatCompletionEvent,
@ -13,17 +13,17 @@ import {
} from '@kbn/inference-common';
import {
ChatCompletionChunkEvent,
TokenCountEvent,
ChatCompletionMessageEvent,
StreamingChatResponseEventType,
} from '../../../../common';
export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
InferenceChatCompletionEvent,
ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent
ChatCompletionChunkEvent | ChatCompletionMessageEvent
> {
return (events$: Observable<InferenceChatCompletionEvent>) => {
return events$.pipe(
filter((event) => event.type !== InferenceChatCompletionEventType.ChatCompletionTokenCount),
map((event) => {
switch (event.type) {
case InferenceChatCompletionEventType.ChatCompletionChunk:
@ -42,16 +42,6 @@ export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
: undefined,
},
} as ChatCompletionChunkEvent;
case InferenceChatCompletionEventType.ChatCompletionTokenCount:
// Convert to TokenCountEvent
return {
type: StreamingChatResponseEventType.TokenCount,
tokens: {
completion: event.tokens.completion,
prompt: event.tokens.prompt,
total: event.tokens.total,
},
} as TokenCountEvent;
case InferenceChatCompletionEventType.ChatCompletionMessage:
// Convert to ChatCompletionMessageEvent
return {
@ -68,6 +58,7 @@ export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
: undefined,
},
} as ChatCompletionMessageEvent;
default:
throw new Error(`Unknown event type`);
}

View file

@ -1,37 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, OperatorFunction, scan, startWith } from 'rxjs';
import {
StreamingChatResponseEvent,
StreamingChatResponseEventType,
TokenCountEvent,
} from '../../../../common/conversation_complete';
export function extractTokenCount(): OperatorFunction<
StreamingChatResponseEvent,
TokenCountEvent['tokens']
> {
return (events$) => {
return events$.pipe(
filter(
(event): event is TokenCountEvent =>
event.type === StreamingChatResponseEventType.TokenCount
),
scan(
(acc, event) => {
acc.completion += event.tokens.completion;
acc.prompt += event.tokens.prompt;
acc.total += event.tokens.total;
return acc;
},
{ completion: 0, prompt: 0, total: 0 }
),
startWith({ completion: 0, prompt: 0, total: 0 })
);
};
}

View file

@ -9,7 +9,6 @@ import { ignoreElements, last, merge, Observable, shareReplay, tap } from 'rxjs'
import { createFunctionNotFoundError, FunctionDefinition } from '../../../../common';
import { ChatEvent } from '../../../../common/conversation_complete';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';
export function failOnNonExistingFunctionCall({
functions,
@ -22,7 +21,6 @@ export function failOnNonExistingFunctionCall({
return merge(
shared$,
shared$.pipe(
withoutTokenCountEvents(),
concatenateChatCompletionChunks(),
last(),
tap((event) => {

View file

@ -4,9 +4,9 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs';
import { filter, lastValueFrom, of, throwError } from 'rxjs';
import { ChatCompleteResponse } from '@kbn/inference-common';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { Message, MessageRole } from '../../../../common';
import { LangTracer } from '../instrumentation/lang_tracer';
import { TITLE_CONVERSATION_FUNCTION_NAME, getGeneratedTitle } from './get_generated_title';
@ -118,44 +118,6 @@ describe('getGeneratedTitle', () => {
expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`);
});
it('ignores token count events and still passes them through', async () => {
const { title$ } = callGenerateTitle([
{
content: '',
toolCalls: [
{
toolCallId: 'test_id',
function: {
name: 'title_conversation',
arguments: {
title: 'My title',
},
},
},
],
tokens: {
completion: 10,
prompt: 10,
total: 10,
},
},
]);
const events = await lastValueFrom(title$.pipe(toArray()));
expect(events).toEqual([
'My title',
{
tokens: {
completion: 10,
prompt: 10,
total: 10,
},
type: StreamingChatResponseEventType.TokenCount,
},
]);
});
it('handles errors in chat and falls back to the default title', async () => {
const chatSpy = jest
.fn()

View file

@ -9,8 +9,7 @@ import { catchError, mergeMap, Observable, of, tap, from } from 'rxjs';
import { Logger } from '@kbn/logging';
import { ChatCompleteResponse } from '@kbn/inference-common';
import type { ObservabilityAIAssistantClient } from '..';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { TokenCountEvent } from '../../../../common/conversation_complete';
import { Message, MessageRole } from '../../../../common';
import { LangTracer } from '../instrumentation/lang_tracer';
export const TITLE_CONVERSATION_FUNCTION_NAME = 'title_conversation';
@ -33,7 +32,7 @@ export function getGeneratedTitle({
chat: ChatFunctionWithoutConnectorAndTokenCount;
logger: Pick<Logger, 'debug' | 'error'>;
tracer: LangTracer;
}): Observable<string | TokenCountEvent> {
}): Observable<string> {
return from(
chat('generate_title', {
systemMessage:
@ -84,21 +83,7 @@ export function getGeneratedTitle({
// - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
title = title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');
const tokenCount: TokenCountEvent | undefined = response.tokens
? {
type: StreamingChatResponseEventType.TokenCount,
tokens: {
completion: response.tokens.completion,
prompt: response.tokens.prompt,
total: response.tokens.total,
},
}
: undefined;
const events: Array<string | TokenCountEvent> = [title];
if (tokenCount) events.push(tokenCount);
return from(events); // Emit each event separately
return of(title);
}),
tap((event) => {
if (typeof event === 'string') {

View file

@ -1,38 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { merge, Observable, partition } from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common';
import {
StreamingChatResponseEventType,
TokenCountEvent,
} from '../../../../common/conversation_complete';
type Hide = <T extends StreamingChatResponseEvent>() => (
source$: Observable<T | TokenCountEvent>
) => Observable<Exclude<T, TokenCountEvent>>;
export function hideTokenCountEvents<T>(
cb: (hide: Hide) => Observable<Exclude<T, TokenCountEvent>>
): Observable<T | TokenCountEvent> {
// `hide` can be called multiple times, so we keep track of each invocation
const allInterceptors: Array<Observable<TokenCountEvent>> = [];
const hide: Hide = () => (source$) => {
const [tokenCountEvents$, otherEvents$] = partition(
source$,
(value): value is TokenCountEvent => value.type === StreamingChatResponseEventType.TokenCount
);
allInterceptors.push(tokenCountEvents$);
return otherEvents$;
};
// combine the two observables again
return merge(cb(hide), ...allInterceptors);
}

View file

@ -25,10 +25,6 @@ const dynamic = {
dynamic: true,
};
const integer = {
type: 'integer' as const,
};
export const conversationComponentTemplate: ClusterComponentTemplate['component_template']['template'] =
{
mappings: {
@ -59,13 +55,6 @@ export const conversationComponentTemplate: ClusterComponentTemplate['component_
id: keyword,
title: text,
last_updated: date,
token_count: {
properties: {
prompt: integer,
completion: integer,
total: integer,
},
},
},
},
namespace: keyword,

View file

@ -11,7 +11,6 @@ import {
BufferFlushEvent,
StreamingChatResponseEventType,
StreamingChatResponseEventWithoutError,
TokenCountEvent,
} from '../../../common/conversation_complete';
// The Cloud proxy currently buffers 4kb or 8kb of data until flushing.
@ -19,7 +18,7 @@ import {
// so we manually insert some data every 250ms if needed to force it
// to flush.
export function flushBuffer<T extends StreamingChatResponseEventWithoutError | TokenCountEvent>(
export function flushBuffer<T extends StreamingChatResponseEventWithoutError>(
isCloud: boolean
): OperatorFunction<T, T | BufferFlushEvent> {
return (source: Observable<T>) =>

View file

@ -24,11 +24,10 @@ import {
ChatCompletionChunkEvent,
StreamingChatResponseEventType,
StreamingChatResponseEventWithoutError,
TokenCountEvent,
} from '../../../common/conversation_complete';
export function observableIntoOpenAIStream(
source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent | TokenCountEvent>,
source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent>,
logger: Logger
) {
const stream = new PassThrough();

View file

@ -110,8 +110,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await simulator.next(`Part: ${i}\n`);
}
await simulator.tokenCount({ completion: 20, prompt: 33, total: 53 });
await simulator.complete();
await new Promise<void>((innerResolve) => passThrough.on('end', () => innerResolve()));
@ -127,15 +125,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
2
)}`
);
const tokenCountChunk = receivedChunks.find((chunk) => chunk.type === 'tokenCount');
expect(tokenCountChunk).to.eql(
{
type: 'tokenCount',
tokens: { completion: 20, prompt: 33, total: 53 },
},
`received token count chunk did not match expected`
);
}
runTest().then(resolve, reject);

View file

@ -12,7 +12,6 @@ import expect from '@kbn/expect';
import {
ChatCompletionChunkEvent,
ConversationCreateEvent,
ConversationUpdateEvent,
MessageAddEvent,
StreamingChatResponseEvent,
StreamingChatResponseEventType,
@ -25,7 +24,7 @@ import {
LlmResponseSimulator,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { createOpenAiChunk } from '../../../../../../observability_ai_assistant_api_integration/common/create_openai_chunk';
import { decodeEvents, getConversationCreatedEvent, getConversationUpdatedEvent } from '../helpers';
import { decodeEvents, getConversationCreatedEvent } from '../helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
import { SupertestWithRoleScope } from '../../../../services/role_scoped_supertest';
@ -104,7 +103,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
],
});
await titleSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
await titleSimulator.complete();
await conversationSimulator.status(200);
@ -170,7 +168,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await simulator.rawWrite(`data: ${chunk.substring(0, 10)}`);
await simulator.rawWrite(`${chunk.substring(10)}\n\n`);
await simulator.tokenCount({ completion: 20, prompt: 33, total: 53 });
await simulator.complete();
await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));
@ -253,7 +250,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
events = await getEvents({}, async (conversationSimulator) => {
await conversationSimulator.next('Hello');
await conversationSimulator.next(' again');
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}).then((_events) => {
return _events.filter(
@ -296,26 +292,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
});
expect(
omit(
events[4],
'conversation.id',
'conversation.last_updated',
'conversation.token_count'
)
).to.eql({
expect(omit(events[4], 'conversation.id', 'conversation.last_updated')).to.eql({
type: StreamingChatResponseEventType.ConversationCreate,
conversation: {
title: 'My generated title',
},
});
const tokenCount = (events[4] as ConversationCreateEvent).conversation.token_count!;
expect(tokenCount.completion).to.be.greaterThan(0);
expect(tokenCount.prompt).to.be.greaterThan(0);
expect(tokenCount.total).to.eql(tokenCount.completion + tokenCount.prompt);
});
after(async () => {
@ -375,7 +357,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);
@ -422,7 +403,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
describe('when updating an existing conversation', () => {
let conversationCreatedEvent: ConversationCreateEvent;
let conversationUpdatedEvent: ConversationUpdateEvent;
before(async () => {
void proxy
@ -499,8 +479,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(updatedResponse.status).to.be(200);
await proxy.waitForAllInterceptorsSettled();
conversationUpdatedEvent = getConversationUpdatedEvent(updatedResponse.body);
});
after(async () => {
@ -515,18 +493,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
expect(status).to.be(200);
});
it('has correct token count for a new conversation', async () => {
expect(conversationCreatedEvent.conversation.token_count?.completion).to.be.greaterThan(0);
expect(conversationCreatedEvent.conversation.token_count?.prompt).to.be.greaterThan(0);
expect(conversationCreatedEvent.conversation.token_count?.total).to.be.greaterThan(0);
});
it('has correct token count for the updated conversation', async () => {
expect(conversationUpdatedEvent.conversation.token_count!.total).to.be.greaterThan(
conversationCreatedEvent.conversation.token_count!.total
);
});
});
// todo

View file

@ -88,7 +88,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
await titleSimulator.status(200);
await titleSimulator.next('My generated title');
await titleSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await titleSimulator.complete();
await conversationSimulator.status(200);
@ -173,7 +172,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);
@ -245,7 +243,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
},
],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);
@ -262,7 +259,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
before(async () => {
responseBody = await getOpenAIResponse(async (conversationSimulator) => {
await conversationSimulator.next('Hello');
await conversationSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
await conversationSimulator.complete();
});
});

View file

@ -45,7 +45,6 @@ export interface LlmResponseSimulator {
}>;
}
) => Promise<void>;
tokenCount: (msg: { completion: number; prompt: number; total: number }) => Promise<void>;
error: (error: any) => Promise<void>;
complete: () => Promise<void>;
rawWrite: (chunk: string) => Promise<void>;
@ -166,17 +165,6 @@ export class LlmProxy {
Connection: 'keep-alive',
});
}),
tokenCount: (msg) => {
const chunk = {
object: 'chat.completion.chunk',
usage: {
completion_tokens: msg.completion,
prompt_tokens: msg.prompt,
total_tokens: msg.total,
},
};
return write(`data: ${JSON.stringify(chunk)}\n\n`);
},
next: (msg) => {
const chunk = createOpenAiChunk(msg);
return write(`data: ${JSON.stringify(chunk)}\n\n`);
@ -220,7 +208,7 @@ export class LlmProxy {
for (const chunk of parsedChunks) {
await simulator.next(chunk);
}
await simulator.tokenCount({ completion: 1, prompt: 1, total: 1 });
await simulator.complete();
},
} as any;

View file

@ -137,11 +137,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
],
conversation: {
title: 'My old conversation',
token_count: {
completion: 1,
prompt: 1,
total: 2,
},
},
'@timestamp': '2024-04-18T14:29:22.948',
public: false,
@ -278,14 +273,10 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
],
});
await titleSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });
await titleSimulator.complete();
await conversationSimulator.next('My response');
await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished();
@ -350,8 +341,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
await conversationSimulator.next('My second response');
await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished();
@ -445,8 +434,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);
await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });
await conversationSimulator.complete();
await header.waitUntilLoadingHasFinished();