Mirror of https://github.com/elastic/kibana.git (synced 2025-04-24 17:59:23 -04:00)
[Obs AI Assistant] Remove TokenCountEvent (#209549)
Closes https://github.com/elastic/kibana/issues/205479

This filters out the `ChatCompletionTokenCountEvent` from the inference plugin, which greatly simplifies handling ChatCompletion events in the Obs AI Assistant.
This commit is contained in:
parent 283cb29606
commit c4826bdfbf
26 changed files with 96 additions and 541 deletions
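For reference, a minimal sketch of the approach this change takes: token count events emitted by the inference plugin are dropped before the remaining events are mapped into the assistant's streaming events. The names follow the diff below; the exact import shapes from @kbn/inference-common are assumptions.

import { filter, type OperatorFunction } from 'rxjs';
import {
  type ChatCompletionEvent as InferenceChatCompletionEvent,
  ChatCompletionEventType as InferenceChatCompletionEventType,
} from '@kbn/inference-common';

// Drop ChatCompletionTokenCount events up front so downstream operators only
// ever see chunk and message events.
export function withoutInferenceTokenCounts(): OperatorFunction<
  InferenceChatCompletionEvent,
  InferenceChatCompletionEvent
> {
  return filter(
    (event) => event.type !== InferenceChatCompletionEventType.ChatCompletionTokenCount
  );
}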
@@ -6,7 +6,7 @@
 */

import { i18n } from '@kbn/i18n';
import { TokenCount as TokenCountType, type Message } from './types';
import { type Message } from './types';

export enum StreamingChatResponseEventType {
  ChatCompletionChunk = 'chatCompletionChunk',

@@ -16,7 +16,6 @@ export enum StreamingChatResponseEventType {
  MessageAdd = 'messageAdd',
  ChatCompletionError = 'chatCompletionError',
  BufferFlush = 'bufferFlush',
  TokenCount = 'tokenCount',
}

type StreamingChatResponseEventBase<

@@ -54,7 +53,6 @@ export type ConversationCreateEvent = StreamingChatResponseEventBase<
      id: string;
      title: string;
      last_updated: string;
      token_count?: TokenCountType;
    };
  }
>;

@@ -66,7 +64,6 @@ export type ConversationUpdateEvent = StreamingChatResponseEventBase<
      id: string;
      title: string;
      last_updated: string;
      token_count?: TokenCountType;
    };
  }
>;

@@ -95,17 +92,6 @@ export type BufferFlushEvent = StreamingChatResponseEventBase<
  }
>;

export type TokenCountEvent = StreamingChatResponseEventBase<
  StreamingChatResponseEventType.TokenCount,
  {
    tokens: {
      completion: number;
      prompt: number;
      total: number;
    };
  }
>;

export type StreamingChatResponseEvent =
  | ChatCompletionChunkEvent
  | ChatCompletionMessageEvent

@@ -113,7 +99,6 @@ export type StreamingChatResponseEvent =
  | ConversationUpdateEvent
  | MessageAddEvent
  | ChatCompletionErrorEvent
  | TokenCountEvent
  | BufferFlushEvent;

export type StreamingChatResponseEventWithoutError = Exclude<

@@ -121,7 +106,7 @@ export type StreamingChatResponseEventWithoutError = Exclude<
  ChatCompletionErrorEvent
>;

export type ChatEvent = ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent;
export type ChatEvent = ChatCompletionChunkEvent | ChatCompletionMessageEvent;
export type MessageOrChatEvent = ChatEvent | MessageAddEvent;

export enum ChatCompletionErrorCode {
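A small usage sketch of the slimmed-down union above (illustrative only; the import path and handler name are assumptions):

import {
  StreamingChatResponseEventType,
  type ChatEvent,
} from './conversation_complete';

// With TokenCountEvent removed, a ChatEvent is either a streamed chunk or the
// final concatenated message, so a simple two-way narrow is enough.
function handleChatEvent(event: ChatEvent) {
  if (event.type === StreamingChatResponseEventType.ChatCompletionChunk) {
    // handle the streamed delta
  } else {
    // handle the completed ChatCompletionMessageEvent
  }
}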
@@ -18,7 +18,6 @@ export {
export type {
  ChatCompletionChunkEvent,
  ChatCompletionMessageEvent,
  TokenCountEvent,
  ConversationCreateEvent,
  ConversationUpdateEvent,
  MessageAddEvent,
@@ -45,12 +45,6 @@ export interface Message {
  };
}

export interface TokenCount {
  prompt: number;
  completion: number;
  total: number;
}

export interface Conversation {
  '@timestamp': string;
  user?: {

@@ -61,7 +55,6 @@ export interface Conversation {
    id: string;
    title: string;
    last_updated: string;
    token_count?: TokenCount;
  };
  systemMessage?: string;
  messages: Message[];

@@ -71,8 +64,8 @@ export interface Conversation {
  public: boolean;
}

export type ConversationRequestBase = Omit<Conversation, 'user' | 'conversation' | 'namespace'> & {
  conversation: { title: string; token_count?: TokenCount; id?: string };
type ConversationRequestBase = Omit<Conversation, 'user' | 'conversation' | 'namespace'> & {
  conversation: { title: string; id?: string };
};

export type ConversationCreateRequest = ConversationRequestBase;
@@ -16,7 +16,6 @@ import {
  withLatestFrom,
  filter,
} from 'rxjs';
import { withoutTokenCountEvents } from './without_token_count_events';
import {
  type ChatCompletionChunkEvent,
  ChatEvent,

@@ -69,15 +68,12 @@ export function emitWithConcatenatedMessage<T extends ChatEvent>(
  return (source$) => {
    const shared = source$.pipe(shareReplay());

    const withoutTokenCount$ = shared.pipe(filterChunkEvents());

    const response$ = concat(
      shared,
      shared.pipe(
        withoutTokenCountEvents(),
        concatenateChatCompletionChunks(),
        last(),
        withLatestFrom(withoutTokenCount$),
        withLatestFrom(shared.pipe(filterChunkEvents())),
        mergeMap(([message, chunkEvent]) => {
          return mergeWithEditedMessage(message, chunkEvent, callback);
        })
@@ -1,23 +0,0 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { filter, OperatorFunction } from 'rxjs';
import {
  StreamingChatResponseEvent,
  StreamingChatResponseEventType,
  TokenCountEvent,
} from '../conversation_complete';

export function withoutTokenCountEvents<T extends StreamingChatResponseEvent>(): OperatorFunction<
  T,
  Exclude<T, TokenCountEvent>
> {
  return filter(
    (event): event is Exclude<T, TokenCountEvent> =>
      event.type !== StreamingChatResponseEventType.TokenCount
  );
}
@@ -65,28 +65,6 @@ export const chatFeedbackEventSchema: EventTypeOpts<ChatFeedback> = {
            description: 'The timestamp of the last message in the conversation.',
          },
        },
        token_count: {
          properties: {
            completion: {
              type: 'long',
              _meta: {
                description: 'The number of tokens in the completion.',
              },
            },
            prompt: {
              type: 'long',
              _meta: {
                description: 'The number of tokens in the prompt.',
              },
            },
            total: {
              type: 'long',
              _meta: {
                description: 'The total number of tokens in the conversation.',
              },
            },
          },
        },
      },
    },
    labels: {
@@ -13,7 +13,6 @@ import { Readable } from 'stream';
import { AssistantScope } from '@kbn/ai-assistant-common';
import { aiAssistantSimulatedFunctionCalling } from '../..';
import { createFunctionResponseMessage } from '../../../common/utils/create_function_response_message';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { LangTracer } from '../../service/client/instrumentation/lang_tracer';
import { flushBuffer } from '../../service/util/flush_buffer';
import { observableIntoOpenAIStream } from '../../service/util/observable_into_openai_stream';

@@ -203,16 +202,14 @@ const chatRecallRoute = createObservabilityAIAssistantServerRoute({
    recallAndScore({
      analytics: (await resources.plugins.core.start()).analytics,
      chat: (name, params) =>
        client
          .chat(name, {
            ...params,
            stream: true,
            connectorId,
            simulateFunctionCalling,
            signal,
            tracer: new LangTracer(otelContext.active()),
          })
          .pipe(withoutTokenCountEvents()),
        client.chat(name, {
          ...params,
          stream: true,
          connectorId,
          simulateFunctionCalling,
          signal,
          tracer: new LangTracer(otelContext.active()),
        }),
      context,
      logger: resources.logger,
      messages: [],
@@ -7,9 +7,7 @@
import * as t from 'io-ts';
import { toBooleanRt } from '@kbn/io-ts-utils';
import {
  type Conversation,
  type ConversationCreateRequest,
  type ConversationRequestBase,
  type ConversationUpdateRequest,
  type Message,
  MessageRole,

@@ -57,17 +55,12 @@ const tokenCountRt = t.type({
  total: t.number,
});

export const baseConversationRt: t.Type<ConversationRequestBase> = t.intersection([
export const conversationCreateRt: t.Type<ConversationCreateRequest> = t.intersection([
  t.type({
    '@timestamp': t.string,
    conversation: t.intersection([
      t.type({
        title: t.string,
      }),
      t.partial({
        token_count: tokenCountRt,
      }),
    ]),
    conversation: t.type({
      title: t.string,
    }),
    messages: t.array(messageRt),
    labels: t.record(t.string, t.string),
    numeric_labels: t.record(t.string, t.number),

@@ -84,17 +77,8 @@ export const assistantScopeType = t.union([
  t.literal('all'),
]);

export const conversationCreateRt: t.Type<ConversationCreateRequest> = t.intersection([
  baseConversationRt,
  t.type({
    conversation: t.type({
      title: t.string,
    }),
  }),
]);

export const conversationUpdateRt: t.Type<ConversationUpdateRequest> = t.intersection([
  baseConversationRt,
  conversationCreateRt,
  t.type({
    conversation: t.intersection([
      t.type({

@@ -102,33 +86,12 @@ export const conversationUpdateRt: t.Type<ConversationUpdateRequest> = t.interse
        title: t.string,
      }),
      t.partial({
        token_count: tokenCountRt,
        token_count: tokenCountRt, // deprecated, but kept for backwards compatibility
      }),
    ]),
  }),
]);

export const conversationRt: t.Type<Conversation> = t.intersection([
  baseConversationRt,
  t.intersection([
    t.type({
      namespace: t.string,
      conversation: t.intersection([
        t.type({
          id: t.string,
          last_updated: t.string,
        }),
        t.partial({
          token_count: tokenCountRt,
        }),
      ]),
    }),
    t.partial({
      user: t.intersection([t.type({ name: t.string }), t.partial({ id: t.string })]),
    }),
  ]),
]);

export const functionRt = t.intersection([
  t.type({
    name: t.string,
@@ -69,21 +69,6 @@ function createLlmSimulator(subscriber: any) {
      toolCalls: msg.function_call ? [{ function: msg.function_call }] : [],
    });
  },
  tokenCount: async ({
    completion,
    prompt,
    total,
  }: {
    completion: number;
    prompt: number;
    total: number;
  }) => {
    subscriber.next({
      type: InferenceChatCompletionEventType.ChatCompletionTokenCount,
      tokens: { completion, prompt, total },
    });
    subscriber.complete();
  },
  chunk: async (msg: ChunkDelta) => {
    subscriber.next({
      type: InferenceChatCompletionEventType.ChatCompletionChunk,

@@ -252,11 +237,6 @@ describe('Observability AI Assistant client', () => {
        },
      },
    ],
    tokens: {
      completion: 0,
      prompt: 0,
      total: 0,
    },
  })
  .catch((error) => titleLlmSimulator.error(error));
};

@@ -388,7 +368,6 @@ describe('Observability AI Assistant client', () => {
  titleLlmPromiseReject(new Error('Failed generating title'));

  await nextTick();
  await llmSimulator.tokenCount({ completion: 1, prompt: 33, total: 34 });
  await llmSimulator.complete();

  await finished(stream);

@@ -400,11 +379,6 @@ describe('Observability AI Assistant client', () => {
      title: 'New conversation',
      id: expect.any(String),
      last_updated: expect.any(String),
      token_count: {
        completion: 1,
        prompt: 33,
        total: 34,
      },
    },
    type: StreamingChatResponseEventType.ConversationCreate,
  });

@@ -418,7 +392,6 @@ describe('Observability AI Assistant client', () => {
  await llmSimulator.chunk({ content: ' again' });

  titleLlmPromiseResolve('An auto-generated title');
  await llmSimulator.tokenCount({ completion: 6, prompt: 210, total: 216 });
  await llmSimulator.complete();

  await finished(stream);

@@ -457,11 +430,6 @@ describe('Observability AI Assistant client', () => {
      title: 'An auto-generated title',
      id: expect.any(String),
      last_updated: expect.any(String),
      token_count: {
        completion: 6,
        prompt: 210,
        total: 216,
      },
    },
    type: StreamingChatResponseEventType.ConversationCreate,
  });

@@ -475,11 +443,6 @@ describe('Observability AI Assistant client', () => {
      id: expect.any(String),
      last_updated: expect.any(String),
      title: 'An auto-generated title',
      token_count: {
        completion: 6,
        prompt: 210,
        total: 216,
      },
    },
    labels: {},
    numeric_labels: {},

@@ -543,11 +506,6 @@ describe('Observability AI Assistant client', () => {
      id: 'my-conversation-id',
      title: 'My stored conversation',
      last_updated: new Date().toISOString(),
      token_count: {
        completion: 1,
        prompt: 78,
        total: 79,
      },
    },
    labels: {},
    numeric_labels: {},

@@ -583,7 +541,6 @@ describe('Observability AI Assistant client', () => {

  await llmSimulator.chunk({ content: 'Hello' });
  await llmSimulator.next({ content: 'Hello' });
  await llmSimulator.tokenCount({ completion: 1, prompt: 33, total: 34 });
  await llmSimulator.complete();

  await finished(stream);

@@ -595,11 +552,6 @@ describe('Observability AI Assistant client', () => {
      title: 'My stored conversation',
      id: expect.any(String),
      last_updated: expect.any(String),
      token_count: {
        completion: 2,
        prompt: 111,
        total: 113,
      },
    },
    type: StreamingChatResponseEventType.ConversationUpdate,
  });

@@ -614,11 +566,6 @@ describe('Observability AI Assistant client', () => {
      id: expect.any(String),
      last_updated: expect.any(String),
      title: 'My stored conversation',
      token_count: {
        completion: 2,
        prompt: 111,
        total: 113,
      },
    },
    labels: {},
    numeric_labels: {},

@@ -900,7 +847,6 @@ describe('Observability AI Assistant client', () => {
  beforeEach(async () => {
    await llmSimulator.chunk({ content: 'I am done here' });
    await llmSimulator.next({ content: 'I am done here' });
    await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
    await llmSimulator.complete();
    await waitForNextWrite(stream);

@@ -940,11 +886,6 @@ describe('Observability AI Assistant client', () => {
      id: expect.any(String),
      last_updated: expect.any(String),
      title: 'My predefined title',
      token_count: {
        completion: expect.any(Number),
        prompt: expect.any(Number),
        total: expect.any(Number),
      },
    },
  });

@@ -1200,7 +1141,6 @@ describe('Observability AI Assistant client', () => {

  await llmSimulator.chunk({ content: 'Hello' });
  await llmSimulator.next({ content: 'Hello' });
  await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
  await llmSimulator.complete();

  await finished(stream);

@@ -1492,7 +1432,6 @@ describe('Observability AI Assistant client', () => {

  await llmSimulator.chunk({ function_call: { name: 'get_top_alerts' } });
  await llmSimulator.next({ content: 'done' });
  await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
  await llmSimulator.complete();

  await waitFor(() => functionResponsePromiseResolve !== undefined);

@@ -1612,7 +1551,6 @@ describe('Observability AI Assistant client', () => {
    function_call: { name: 'my_action', arguments: JSON.stringify({ foo: 'bar' }) },
  });
  await llmSimulator.next({ content: 'content' });
  await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
  await llmSimulator.complete();
});

@@ -1643,7 +1581,6 @@ describe('Observability AI Assistant client', () => {
  await llmSimulator.next({
    content: 'Looks like the function call failed',
  });
  await llmSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });

  await llmSimulator.complete();
});
@@ -41,7 +41,6 @@ import {
  ConversationUpdateEvent,
  createConversationNotFoundError,
  StreamingChatResponseEventType,
  TokenCountEvent,
  type StreamingChatResponseEvent,
} from '../../../common/conversation_complete';
import { convertMessagesForInference } from '../../../common/convert_messages_for_inference';

@@ -56,7 +55,6 @@ import {
  KnowledgeBaseType,
  KnowledgeBaseEntryRole,
} from '../../../common/types';
import { withoutTokenCountEvents } from '../../../common/utils/without_token_count_events';
import { CONTEXT_FUNCTION_NAME } from '../../functions/context';
import type { ChatFunctionClient } from '../chat_function_client';
import { KnowledgeBaseService, RecalledEntry } from '../knowledge_base_service';

@@ -68,9 +66,7 @@ import { LangTracer } from './instrumentation/lang_tracer';
import { continueConversation } from './operators/continue_conversation';
import { convertInferenceEventsToStreamingEvents } from './operators/convert_inference_events_to_streaming_events';
import { extractMessages } from './operators/extract_messages';
import { extractTokenCount } from './operators/extract_token_count';
import { getGeneratedTitle } from './operators/get_generated_title';
import { instrumentAndCountTokens } from './operators/instrument_and_count_tokens';
import {
  reIndexKnowledgeBaseAndPopulateSemanticTextField,
  scheduleKbSemanticTextMigrationTask,

@@ -78,6 +74,7 @@ import {
import { ObservabilityAIAssistantPluginStartDependencies } from '../../types';
import { ObservabilityAIAssistantConfig } from '../../config';
import { getElserModelId } from '../knowledge_base_service/get_elser_model_id';
import { apmInstrumentation } from './operators/apm_instrumentation';

const MAX_FUNCTION_CALLS = 8;

@@ -297,17 +294,12 @@ export class ObservabilityAIAssistantClient {
  // wait until all dependencies have completed
  forkJoin([
    // get just the new messages
    nextEvents$.pipe(withoutTokenCountEvents(), extractMessages()),
    // count all the token count events emitted during completion
    mergeOperator(
      nextEvents$,
      title$.pipe(filter((value): value is TokenCountEvent => typeof value !== 'string'))
    ).pipe(extractTokenCount()),
    nextEvents$.pipe(extractMessages()),
    // get just the title, and drop the token count events
    title$.pipe(filter((value): value is string => typeof value === 'string')),
    systemMessage$,
  ]).pipe(
    switchMap(([addedMessages, tokenCountResult, title, systemMessage]) => {
    switchMap(([addedMessages, title, systemMessage]) => {
      const initialMessagesWithAddedMessages = initialMessages.concat(addedMessages);

      const lastMessage = last(initialMessagesWithAddedMessages);

@@ -329,13 +321,6 @@ export class ObservabilityAIAssistantClient {
        return throwError(() => createConversationNotFoundError());
      }

      const persistedTokenCount = conversation._source?.conversation
        .token_count ?? {
        prompt: 0,
        completion: 0,
        total: 0,
      };

      return from(
        this.update(
          conversationId,

@@ -349,16 +334,10 @@ export class ObservabilityAIAssistantClient {
          // update messages and system message
          { messages: initialMessagesWithAddedMessages, systemMessage },

          // update token count
          // update title
          {
            conversation: {
              title: title || conversation._source?.conversation.title,
              token_count: {
                prompt: persistedTokenCount.prompt + tokenCountResult.prompt,
                completion:
                  persistedTokenCount.completion + tokenCountResult.completion,
                total: persistedTokenCount.total + tokenCountResult.total,
              },
            },
          }
        )

@@ -382,7 +361,6 @@ export class ObservabilityAIAssistantClient {
        conversation: {
          title,
          id: conversationId,
          token_count: tokenCountResult,
        },
        public: !!isPublic,
        labels: {},

@@ -403,8 +381,7 @@ export class ObservabilityAIAssistantClient {
    );

    return output$.pipe(
      instrumentAndCountTokens('complete'),
      withoutTokenCountEvents(),
      apmInstrumentation('complete'),
      catchError((error) => {
        this.dependencies.logger.error(error);
        return throwError(() => error);

@@ -462,7 +439,7 @@ export class ObservabilityAIAssistantClient {
      stream: TStream;
    }
  ): TStream extends true
    ? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
    ? Observable<ChatCompletionChunkEvent | ChatCompletionMessageEvent>
    : Promise<ChatCompleteResponse> {
    let tools: Record<string, { description: string; schema: any }> | undefined;
    let toolChoice: ToolChoiceType | { function: string } | undefined;

@@ -500,7 +477,7 @@ export class ObservabilityAIAssistantClient {
      })
    ).pipe(
      convertInferenceEventsToStreamingEvents(),
      instrumentAndCountTokens(name),
      apmInstrumentation(name),
      failOnNonExistingFunctionCall({ functions }),
      tap((event) => {
        if (

@@ -512,7 +489,7 @@ export class ObservabilityAIAssistantClient {
      }),
      shareReplay()
    ) as TStream extends true
      ? Observable<ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent>
      ? Observable<ChatCompletionChunkEvent | ChatCompletionMessageEvent>
      : never;
  } else {
    return this.dependencies.inferenceClient.chatComplete({
@@ -18,9 +18,8 @@ import {
  finalize,
} from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common/conversation_complete';
import { extractTokenCount } from './extract_token_count';

export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
export function apmInstrumentation<T extends StreamingChatResponseEvent>(
  name: string
): OperatorFunction<T, T> {
  return (source$) => {

@@ -35,19 +34,9 @@ export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(

    const shared$ = source$.pipe(shareReplay());

    let tokenCount = {
      prompt: 0,
      completion: 0,
      total: 0,
    };

    return merge(
      shared$,
      shared$.pipe(
        extractTokenCount(),
        tap((nextTokenCount) => {
          tokenCount = nextTokenCount;
        }),
        last(),
        tap(() => {
          span?.setOutcome('success');

@@ -57,11 +46,6 @@ export function instrumentAndCountTokens<T extends StreamingChatResponseEvent>(
          return throwError(() => error);
        }),
        finalize(() => {
          span?.addLabels({
            tokenCountPrompt: tokenCount.prompt,
            tokenCountCompletion: tokenCount.completion,
            tokenCountTotal: tokenCount.total,
          });
          span?.end();
        }),
        ignoreElements()
@@ -31,14 +31,12 @@ import { FunctionVisibility } from '../../../../common/functions/types';
import { AdHocInstruction, Instruction } from '../../../../common/types';
import { createFunctionResponseMessage } from '../../../../common/utils/create_function_response_message';
import { emitWithConcatenatedMessage } from '../../../../common/utils/emit_with_concatenated_message';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';
import type { ChatFunctionClient } from '../../chat_function_client';
import type { AutoAbortedChatFunction } from '../../types';
import { createServerSideFunctionResponseError } from '../../util/create_server_side_function_response_error';
import { LangTracer } from '../instrumentation/lang_tracer';
import { catchFunctionNotFoundError } from './catch_function_not_found_error';
import { extractMessages } from './extract_messages';
import { hideTokenCountEvents } from './hide_token_count_events';

const MAX_FUNCTION_RESPONSE_TOKEN_COUNT = 4000;

@@ -68,67 +66,65 @@ function executeFunctionAndCatchError({
  // hide token count events from functions to prevent them from
  // having to deal with it as well

  return tracer.startActiveSpan(`execute_function ${name}`, ({ tracer: nextTracer }) =>
    hideTokenCountEvents((hide) => {
      const executeFunctionResponse$ = from(
        functionClient.executeFunction({
          name,
          chat: (operationName, params) => {
            return chat(operationName, {
              ...params,
              tracer: nextTracer,
              connectorId,
            }).pipe(hide());
          },
          args,
          signal,
          messages,
          connectorId,
          simulateFunctionCalling,
        })
      );
  return tracer.startActiveSpan(`execute_function ${name}`, ({ tracer: nextTracer }) => {
    const executeFunctionResponse$ = from(
      functionClient.executeFunction({
        name,
        chat: (operationName, params) => {
          return chat(operationName, {
            ...params,
            tracer: nextTracer,
            connectorId,
          });
        },
        args,
        signal,
        messages,
        connectorId,
        simulateFunctionCalling,
      })
    );

      return executeFunctionResponse$.pipe(
        catchError((error) => {
          logger.error(`Encountered error running function ${name}: ${JSON.stringify(error)}`);
          // We want to catch the error only when a promise occurs
          // if it occurs in the Observable, we cannot easily recover
          // from it because the function may have already emitted
          // values which could lead to an invalid conversation state,
          // so in that case we let the stream fail.
          return of(createServerSideFunctionResponseError({ name, error }));
        }),
        switchMap((response) => {
          if (isObservable(response)) {
            return response;
          }
    return executeFunctionResponse$.pipe(
      catchError((error) => {
        logger.error(`Encountered error running function ${name}: ${JSON.stringify(error)}`);
        // We want to catch the error only when a promise occurs
        // if it occurs in the Observable, we cannot easily recover
        // from it because the function may have already emitted
        // values which could lead to an invalid conversation state,
        // so in that case we let the stream fail.
        return of(createServerSideFunctionResponseError({ name, error }));
      }),
      switchMap((response) => {
        if (isObservable(response)) {
          return response;
        }

          // is messageAdd event
          if ('type' in response) {
            return of(response);
          }
        // is messageAdd event
        if ('type' in response) {
          return of(response);
        }

          const encoded = encode(JSON.stringify(response.content || {}));
        const encoded = encode(JSON.stringify(response.content || {}));

          const exceededTokenLimit = encoded.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT;
        const exceededTokenLimit = encoded.length >= MAX_FUNCTION_RESPONSE_TOKEN_COUNT;

          return of(
            createFunctionResponseMessage({
              name,
              content: exceededTokenLimit
                ? {
                    message:
                      'Function response exceeded the maximum length allowed and was truncated',
                    truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
                  }
                : response.content,
              data: response.data,
            })
          );
        })
      );
    })
  );
        return of(
          createFunctionResponseMessage({
            name,
            content: exceededTokenLimit
              ? {
                  message:
                    'Function response exceeded the maximum length allowed and was truncated',
                  truncated: decode(take(encoded, MAX_FUNCTION_RESPONSE_TOKEN_COUNT)),
                }
              : response.content,
            data: response.data,
          })
        );
      })
    );
  });
}

function getFunctionDefinitions({

@@ -315,7 +311,6 @@ export function continueConversation({
  return concat(
    shared$,
    shared$.pipe(
      withoutTokenCountEvents(),
      extractMessages(),
      switchMap((extractedMessages) => {
        if (!extractedMessages.length) {
@@ -5,7 +5,7 @@
 * 2.0.
 */

import { Observable, OperatorFunction, map } from 'rxjs';
import { Observable, OperatorFunction, filter, map } from 'rxjs';
import { v4 } from 'uuid';
import {
  ChatCompletionEvent as InferenceChatCompletionEvent,

@@ -13,17 +13,17 @@
} from '@kbn/inference-common';
import {
  ChatCompletionChunkEvent,
  TokenCountEvent,
  ChatCompletionMessageEvent,
  StreamingChatResponseEventType,
} from '../../../../common';

export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
  InferenceChatCompletionEvent,
  ChatCompletionChunkEvent | TokenCountEvent | ChatCompletionMessageEvent
  ChatCompletionChunkEvent | ChatCompletionMessageEvent
> {
  return (events$: Observable<InferenceChatCompletionEvent>) => {
    return events$.pipe(
      filter((event) => event.type !== InferenceChatCompletionEventType.ChatCompletionTokenCount),
      map((event) => {
        switch (event.type) {
          case InferenceChatCompletionEventType.ChatCompletionChunk:

@@ -42,16 +42,6 @@ export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
                : undefined,
            },
          } as ChatCompletionChunkEvent;
          case InferenceChatCompletionEventType.ChatCompletionTokenCount:
            // Convert to TokenCountEvent
            return {
              type: StreamingChatResponseEventType.TokenCount,
              tokens: {
                completion: event.tokens.completion,
                prompt: event.tokens.prompt,
                total: event.tokens.total,
              },
            } as TokenCountEvent;
          case InferenceChatCompletionEventType.ChatCompletionMessage:
            // Convert to ChatCompletionMessageEvent
            return {

@@ -68,6 +58,7 @@ export function convertInferenceEventsToStreamingEvents(): OperatorFunction<
                : undefined,
            },
          } as ChatCompletionMessageEvent;

          default:
            throw new Error(`Unknown event type`);
        }
@@ -1,37 +0,0 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { filter, OperatorFunction, scan, startWith } from 'rxjs';
import {
  StreamingChatResponseEvent,
  StreamingChatResponseEventType,
  TokenCountEvent,
} from '../../../../common/conversation_complete';

export function extractTokenCount(): OperatorFunction<
  StreamingChatResponseEvent,
  TokenCountEvent['tokens']
> {
  return (events$) => {
    return events$.pipe(
      filter(
        (event): event is TokenCountEvent =>
          event.type === StreamingChatResponseEventType.TokenCount
      ),
      scan(
        (acc, event) => {
          acc.completion += event.tokens.completion;
          acc.prompt += event.tokens.prompt;
          acc.total += event.tokens.total;
          return acc;
        },
        { completion: 0, prompt: 0, total: 0 }
      ),
      startWith({ completion: 0, prompt: 0, total: 0 })
    );
  };
}
@@ -9,7 +9,6 @@ import { ignoreElements, last, merge, Observable, shareReplay, tap } from 'rxjs'
import { createFunctionNotFoundError, FunctionDefinition } from '../../../../common';
import { ChatEvent } from '../../../../common/conversation_complete';
import { concatenateChatCompletionChunks } from '../../../../common/utils/concatenate_chat_completion_chunks';
import { withoutTokenCountEvents } from '../../../../common/utils/without_token_count_events';

export function failOnNonExistingFunctionCall({
  functions,

@@ -22,7 +21,6 @@ export function failOnNonExistingFunctionCall({
  return merge(
    shared$,
    shared$.pipe(
      withoutTokenCountEvents(),
      concatenateChatCompletionChunks(),
      last(),
      tap((event) => {
@@ -4,9 +4,9 @@
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */
import { filter, lastValueFrom, of, throwError, toArray } from 'rxjs';
import { filter, lastValueFrom, of, throwError } from 'rxjs';
import { ChatCompleteResponse } from '@kbn/inference-common';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { Message, MessageRole } from '../../../../common';
import { LangTracer } from '../instrumentation/lang_tracer';
import { TITLE_CONVERSATION_FUNCTION_NAME, getGeneratedTitle } from './get_generated_title';

@@ -118,44 +118,6 @@ describe('getGeneratedTitle', () => {
    expect(await testTitle(`"User's request for a title"`)).toEqual(`User's request for a title`);
  });

  it('ignores token count events and still passes them through', async () => {
    const { title$ } = callGenerateTitle([
      {
        content: '',
        toolCalls: [
          {
            toolCallId: 'test_id',
            function: {
              name: 'title_conversation',
              arguments: {
                title: 'My title',
              },
            },
          },
        ],
        tokens: {
          completion: 10,
          prompt: 10,
          total: 10,
        },
      },
    ]);

    const events = await lastValueFrom(title$.pipe(toArray()));

    expect(events).toEqual([
      'My title',
      {
        tokens: {
          completion: 10,
          prompt: 10,
          total: 10,
        },
        type: StreamingChatResponseEventType.TokenCount,
      },
    ]);
  });

  it('handles errors in chat and falls back to the default title', async () => {
    const chatSpy = jest
      .fn()
@@ -9,8 +9,7 @@ import { catchError, mergeMap, Observable, of, tap, from } from 'rxjs';
import { Logger } from '@kbn/logging';
import { ChatCompleteResponse } from '@kbn/inference-common';
import type { ObservabilityAIAssistantClient } from '..';
import { Message, MessageRole, StreamingChatResponseEventType } from '../../../../common';
import { TokenCountEvent } from '../../../../common/conversation_complete';
import { Message, MessageRole } from '../../../../common';
import { LangTracer } from '../instrumentation/lang_tracer';

export const TITLE_CONVERSATION_FUNCTION_NAME = 'title_conversation';

@@ -33,7 +32,7 @@ export function getGeneratedTitle({
  chat: ChatFunctionWithoutConnectorAndTokenCount;
  logger: Pick<Logger, 'debug' | 'error'>;
  tracer: LangTracer;
}): Observable<string | TokenCountEvent> {
}): Observable<string> {
  return from(
    chat('generate_title', {
      systemMessage:

@@ -84,21 +83,7 @@ export function getGeneratedTitle({
      // - JustTextWithoutQuotes => Captures: JustTextWithoutQuotes
      title = title.replace(/^"(.*)"$/g, '$1').replace(/^'(.*)'$/g, '$1');

      const tokenCount: TokenCountEvent | undefined = response.tokens
        ? {
            type: StreamingChatResponseEventType.TokenCount,
            tokens: {
              completion: response.tokens.completion,
              prompt: response.tokens.prompt,
              total: response.tokens.total,
            },
          }
        : undefined;

      const events: Array<string | TokenCountEvent> = [title];
      if (tokenCount) events.push(tokenCount);

      return from(events); // Emit each event separately
      return of(title);
    }),
    tap((event) => {
      if (typeof event === 'string') {
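A minimal call-site sketch for the simplified operator above: getGeneratedTitle now emits only the title string, so a caller can await it directly. The params object is elided here and stands in for the arguments shown in the signature hunk above.

import { lastValueFrom, type Observable } from 'rxjs';

// title$ no longer interleaves TokenCountEvent values with the title.
const title$: Observable<string> = getGeneratedTitle(params);
const title = await lastValueFrom(title$);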
@@ -1,38 +0,0 @@
/*
 * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
 * or more contributor license agreements. Licensed under the Elastic License
 * 2.0; you may not use this file except in compliance with the Elastic License
 * 2.0.
 */

import { merge, Observable, partition } from 'rxjs';
import type { StreamingChatResponseEvent } from '../../../../common';
import {
  StreamingChatResponseEventType,
  TokenCountEvent,
} from '../../../../common/conversation_complete';

type Hide = <T extends StreamingChatResponseEvent>() => (
  source$: Observable<T | TokenCountEvent>
) => Observable<Exclude<T, TokenCountEvent>>;

export function hideTokenCountEvents<T>(
  cb: (hide: Hide) => Observable<Exclude<T, TokenCountEvent>>
): Observable<T | TokenCountEvent> {
  // `hide` can be called multiple times, so we keep track of each invocation
  const allInterceptors: Array<Observable<TokenCountEvent>> = [];

  const hide: Hide = () => (source$) => {
    const [tokenCountEvents$, otherEvents$] = partition(
      source$,
      (value): value is TokenCountEvent => value.type === StreamingChatResponseEventType.TokenCount
    );

    allInterceptors.push(tokenCountEvents$);

    return otherEvents$;
  };

  // combine the two observables again
  return merge(cb(hide), ...allInterceptors);
}
@@ -25,10 +25,6 @@ const dynamic = {
  dynamic: true,
};

const integer = {
  type: 'integer' as const,
};

export const conversationComponentTemplate: ClusterComponentTemplate['component_template']['template'] =
  {
    mappings: {

@@ -59,13 +55,6 @@ export const conversationComponentTemplate: ClusterComponentTemplate['component_
          id: keyword,
          title: text,
          last_updated: date,
          token_count: {
            properties: {
              prompt: integer,
              completion: integer,
              total: integer,
            },
          },
        },
      },
      namespace: keyword,
@@ -11,7 +11,6 @@ import {
  BufferFlushEvent,
  StreamingChatResponseEventType,
  StreamingChatResponseEventWithoutError,
  TokenCountEvent,
} from '../../../common/conversation_complete';

// The Cloud proxy currently buffers 4kb or 8kb of data until flushing.

@@ -19,7 +18,7 @@
// so we manually insert some data every 250ms if needed to force it
// to flush.

export function flushBuffer<T extends StreamingChatResponseEventWithoutError | TokenCountEvent>(
export function flushBuffer<T extends StreamingChatResponseEventWithoutError>(
  isCloud: boolean
): OperatorFunction<T, T | BufferFlushEvent> {
  return (source: Observable<T>) =>
@@ -24,11 +24,10 @@ import {
  ChatCompletionChunkEvent,
  StreamingChatResponseEventType,
  StreamingChatResponseEventWithoutError,
  TokenCountEvent,
} from '../../../common/conversation_complete';

export function observableIntoOpenAIStream(
  source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent | TokenCountEvent>,
  source: Observable<StreamingChatResponseEventWithoutError | BufferFlushEvent>,
  logger: Logger
) {
  const stream = new PassThrough();
@@ -110,8 +110,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    await simulator.next(`Part: ${i}\n`);
  }

  await simulator.tokenCount({ completion: 20, prompt: 33, total: 53 });

  await simulator.complete();

  await new Promise<void>((innerResolve) => passThrough.on('end', () => innerResolve()));

@@ -127,15 +125,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
      2
    )}`
  );

  const tokenCountChunk = receivedChunks.find((chunk) => chunk.type === 'tokenCount');
  expect(tokenCountChunk).to.eql(
    {
      type: 'tokenCount',
      tokens: { completion: 20, prompt: 33, total: 53 },
    },
    `received token count chunk did not match expected`
  );
}

runTest().then(resolve, reject);
@@ -12,7 +12,6 @@ import expect from '@kbn/expect';
import {
  ChatCompletionChunkEvent,
  ConversationCreateEvent,
  ConversationUpdateEvent,
  MessageAddEvent,
  StreamingChatResponseEvent,
  StreamingChatResponseEventType,

@@ -25,7 +24,7 @@ import {
  LlmResponseSimulator,
} from '../../../../../../observability_ai_assistant_api_integration/common/create_llm_proxy';
import { createOpenAiChunk } from '../../../../../../observability_ai_assistant_api_integration/common/create_openai_chunk';
import { decodeEvents, getConversationCreatedEvent, getConversationUpdatedEvent } from '../helpers';
import { decodeEvents, getConversationCreatedEvent } from '../helpers';
import type { DeploymentAgnosticFtrProviderContext } from '../../../../ftr_provider_context';
import { SupertestWithRoleScope } from '../../../../services/role_scoped_supertest';

@@ -104,7 +103,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    },
  ],
});
await titleSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
await titleSimulator.complete();

await conversationSimulator.status(200);

@@ -170,7 +168,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

await simulator.rawWrite(`data: ${chunk.substring(0, 10)}`);
await simulator.rawWrite(`${chunk.substring(10)}\n\n`);
await simulator.tokenCount({ completion: 20, prompt: 33, total: 53 });
await simulator.complete();

await new Promise<void>((resolve) => passThrough.on('end', () => resolve()));

@@ -253,7 +250,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
events = await getEvents({}, async (conversationSimulator) => {
  await conversationSimulator.next('Hello');
  await conversationSimulator.next(' again');
  await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
  await conversationSimulator.complete();
}).then((_events) => {
  return _events.filter(

@@ -296,26 +292,12 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
  },
});

expect(
  omit(
    events[4],
    'conversation.id',
    'conversation.last_updated',
    'conversation.token_count'
  )
).to.eql({
expect(omit(events[4], 'conversation.id', 'conversation.last_updated')).to.eql({
  type: StreamingChatResponseEventType.ConversationCreate,
  conversation: {
    title: 'My generated title',
  },
});

const tokenCount = (events[4] as ConversationCreateEvent).conversation.token_count!;

expect(tokenCount.completion).to.be.greaterThan(0);
expect(tokenCount.prompt).to.be.greaterThan(0);

expect(tokenCount.total).to.eql(tokenCount.completion + tokenCount.prompt);
});

after(async () => {

@@ -375,7 +357,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    },
  ],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);

@@ -422,7 +403,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

describe('when updating an existing conversation', () => {
  let conversationCreatedEvent: ConversationCreateEvent;
  let conversationUpdatedEvent: ConversationUpdateEvent;

  before(async () => {
    void proxy

@@ -499,8 +479,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    expect(updatedResponse.status).to.be(200);

    await proxy.waitForAllInterceptorsSettled();

    conversationUpdatedEvent = getConversationUpdatedEvent(updatedResponse.body);
  });

  after(async () => {

@@ -515,18 +493,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

    expect(status).to.be(200);
  });

  it('has correct token count for a new conversation', async () => {
    expect(conversationCreatedEvent.conversation.token_count?.completion).to.be.greaterThan(0);
    expect(conversationCreatedEvent.conversation.token_count?.prompt).to.be.greaterThan(0);
    expect(conversationCreatedEvent.conversation.token_count?.total).to.be.greaterThan(0);
  });

  it('has correct token count for the updated conversation', async () => {
    expect(conversationUpdatedEvent.conversation.token_count!.total).to.be.greaterThan(
      conversationCreatedEvent.conversation.token_count!.total
    );
  });
});

// todo
@@ -88,7 +88,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon

await titleSimulator.status(200);
await titleSimulator.next('My generated title');
await titleSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await titleSimulator.complete();

await conversationSimulator.status(200);

@@ -173,7 +172,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    },
  ],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);

@@ -245,7 +243,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
    },
  ],
});
await conversationSimulator.tokenCount({ completion: 0, prompt: 0, total: 0 });
await conversationSimulator.complete();
}
);

@@ -262,7 +259,6 @@ export default function ApiTest({ getService }: DeploymentAgnosticFtrProviderCon
before(async () => {
  responseBody = await getOpenAIResponse(async (conversationSimulator) => {
    await conversationSimulator.next('Hello');
    await conversationSimulator.tokenCount({ completion: 5, prompt: 10, total: 15 });
    await conversationSimulator.complete();
  });
});
@@ -45,7 +45,6 @@ export interface LlmResponseSimulator {
    }>;
  }
) => Promise<void>;
tokenCount: (msg: { completion: number; prompt: number; total: number }) => Promise<void>;
error: (error: any) => Promise<void>;
complete: () => Promise<void>;
rawWrite: (chunk: string) => Promise<void>;

@@ -166,17 +165,6 @@ export class LlmProxy {
    Connection: 'keep-alive',
  });
}),
tokenCount: (msg) => {
  const chunk = {
    object: 'chat.completion.chunk',
    usage: {
      completion_tokens: msg.completion,
      prompt_tokens: msg.prompt,
      total_tokens: msg.total,
    },
  };
  return write(`data: ${JSON.stringify(chunk)}\n\n`);
},
next: (msg) => {
  const chunk = createOpenAiChunk(msg);
  return write(`data: ${JSON.stringify(chunk)}\n\n`);

@@ -220,7 +208,7 @@ export class LlmProxy {
  for (const chunk of parsedChunks) {
    await simulator.next(chunk);
  }
  await simulator.tokenCount({ completion: 1, prompt: 1, total: 1 });

  await simulator.complete();
},
} as any;
@@ -137,11 +137,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
  ],
  conversation: {
    title: 'My old conversation',
    token_count: {
      completion: 1,
      prompt: 1,
      total: 2,
    },
  },
  '@timestamp': '2024-04-18T14:29:22.948',
  public: false,

@@ -278,14 +273,10 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
  ],
});

await titleSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });

await titleSimulator.complete();

await conversationSimulator.next('My response');

await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });

await conversationSimulator.complete();

await header.waitUntilLoadingHasFinished();

@@ -350,8 +341,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte

await conversationSimulator.next('My second response');

await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });

await conversationSimulator.complete();

await header.waitUntilLoadingHasFinished();

@@ -445,8 +434,6 @@ export default function ApiTest({ getService, getPageObjects }: FtrProviderConte
  'Service Level Indicators (SLIs) are quantifiable defined metrics that measure the performance and availability of a service or distributed system.'
);

await conversationSimulator.tokenCount({ completion: 1, prompt: 1, total: 2 });

await conversationSimulator.complete();

await header.waitUntilLoadingHasFinished();