Mirror of https://github.com/elastic/kibana.git
[inference] openAI: fallback to manual token count when not provided in response (#207722)
## Summary

Fix https://github.com/elastic/kibana/issues/207719

For OpenAI providers that do not emit token usage metadata on the stream API, manually count tokens so that a `tokenCount` event is always emitted.
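For stream consumers, the practical effect is that a token-count event can now be awaited unconditionally, whether the provider reported usage metadata or the fallback estimate kicked in. A minimal consumer-side sketch (the `getTokenUsage` helper is hypothetical; it assumes the `ChatCompletionEvent` union and the `isChatCompletionTokenCountEvent` guard from `@kbn/inference-common`):

```ts
import { Observable, lastValueFrom, filter, toArray } from 'rxjs';
import {
  ChatCompletionEvent,
  isChatCompletionTokenCountEvent,
} from '@kbn/inference-common';

// response$ is the event stream returned by an adapter's chatComplete().
// After this change a token-count event is always emitted, so this
// resolves once the stream completes, for every OpenAI-compatible provider.
async function getTokenUsage(response$: Observable<ChatCompletionEvent>) {
  const [tokenCount] = await lastValueFrom(
    response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
  );
  return tokenCount.tokens; // { prompt, completion, total }
}
```

Whether the numbers come from provider metadata or from the manual fallback added here is transparent to the caller.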
Parent: 4073aff617
Commit: c087c984ff
15 changed files with 729 additions and 42 deletions
@@ -8,10 +8,15 @@
 import OpenAI from 'openai';
 import { v4 } from 'uuid';
 import { PassThrough } from 'stream';
-import { lastValueFrom, Subject, toArray } from 'rxjs';
+import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
 import type { Logger } from '@kbn/logging';
 import { loggerMock } from '@kbn/logging-mocks';
-import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common';
+import {
+  ChatCompletionEventType,
+  MessageRole,
+  isChatCompletionChunkEvent,
+  isChatCompletionTokenCountEvent,
+} from '@kbn/inference-common';
 import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
 import { InferenceExecutor } from '../../utils/inference_executor';
 import { inferenceAdapter } from './inference_adapter';
@@ -110,7 +115,9 @@ describe('inferenceAdapter', () => {
 
     source$.complete();
 
-    const allChunks = await lastValueFrom(response$.pipe(toArray()));
+    const allChunks = await lastValueFrom(
+      response$.pipe(filter(isChatCompletionChunkEvent), toArray())
+    );
 
     expect(allChunks).toEqual([
       {
@@ -126,6 +133,105 @@ describe('inferenceAdapter', () => {
     ]);
   });
 
+  it('emits token count event when provided by the response', async () => {
+    const source$ = new Subject<Record<string, any>>();
+
+    executorMock.invoke.mockImplementation(async () => {
+      return {
+        actionId: '',
+        status: 'ok',
+        data: observableIntoEventSourceStream(source$, logger),
+      };
+    });
+
+    const response$ = inferenceAdapter.chatComplete({
+      ...defaultArgs,
+      messages: [
+        {
+          role: MessageRole.User,
+          content: 'Hello',
+        },
+      ],
+    });
+
+    source$.next(
+      createOpenAIChunk({
+        delta: {
+          content: 'First',
+        },
+        usage: {
+          completion_tokens: 5,
+          prompt_tokens: 10,
+          total_tokens: 15,
+        },
+      })
+    );
+
+    source$.complete();
+
+    const tokenChunks = await lastValueFrom(
+      response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
+    );
+
+    expect(tokenChunks).toEqual([
+      {
+        type: ChatCompletionEventType.ChatCompletionTokenCount,
+        tokens: {
+          completion: 5,
+          prompt: 10,
+          total: 15,
+        },
+      },
+    ]);
+  });
+
+  it('emits token count event when not provided by the response', async () => {
+    const source$ = new Subject<Record<string, any>>();
+
+    executorMock.invoke.mockImplementation(async () => {
+      return {
+        actionId: '',
+        status: 'ok',
+        data: observableIntoEventSourceStream(source$, logger),
+      };
+    });
+
+    const response$ = inferenceAdapter.chatComplete({
+      ...defaultArgs,
+      messages: [
+        {
+          role: MessageRole.User,
+          content: 'Hello',
+        },
+      ],
+    });
+
+    source$.next(
+      createOpenAIChunk({
+        delta: {
+          content: 'First',
+        },
+      })
+    );
+
+    source$.complete();
+
+    const tokenChunks = await lastValueFrom(
+      response$.pipe(filter(isChatCompletionTokenCountEvent), toArray())
+    );
+
+    expect(tokenChunks).toEqual([
+      {
+        type: ChatCompletionEventType.ChatCompletionTokenCount,
+        tokens: {
+          completion: expect.any(Number),
+          prompt: expect.any(Number),
+          total: expect.any(Number),
+        },
+      },
+    ]);
+  });
+
   it('propagates the abort signal when provided', () => {
     const abortController = new AbortController();
 
@@ -20,6 +20,7 @@ import {
   toolChoiceToOpenAI,
   messagesToOpenAI,
   processOpenAIStream,
+  emitTokenCountEstimateIfMissing,
 } from '../openai';
 
 export const inferenceAdapter: InferenceConnectorAdapter = {
@@ -85,6 +86,7 @@ export const inferenceAdapter: InferenceConnectorAdapter = {
         );
       }),
       processOpenAIStream(),
+      emitTokenCountEstimateIfMissing({ request }),
      simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
     );
   },
@@ -0,0 +1,65 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { of, toArray, lastValueFrom } from 'rxjs';
+import { chunkEvent, tokensEvent } from '../../../test_utils';
+import type { OpenAIRequest } from './types';
+import { emitTokenCountEstimateIfMissing } from './emit_token_count_if_missing';
+
+jest.mock('./manually_count_tokens');
+import { manuallyCountPromptTokens, manuallyCountCompletionTokens } from './manually_count_tokens';
+const manuallyCountPromptTokensMock = manuallyCountPromptTokens as jest.MockedFn<
+  typeof manuallyCountPromptTokens
+>;
+const manuallyCountCompletionTokensMock = manuallyCountCompletionTokens as jest.MockedFn<
+  typeof manuallyCountCompletionTokens
+>;
+
+const stubRequest = (content: string = 'foo'): OpenAIRequest => {
+  return {
+    messages: [{ role: 'user', content }],
+  };
+};
+
+describe('emitTokenCountEstimateIfMissing', () => {
+  beforeEach(() => {
+    manuallyCountPromptTokensMock.mockReset();
+    manuallyCountCompletionTokensMock.mockReset();
+  });
+
+  it('mirrors the source when token count is emitted', async () => {
+    const events = [
+      chunkEvent('chunk-1'),
+      chunkEvent('chunk-2'),
+      chunkEvent('chunk-3'),
+      tokensEvent({ completion: 5, prompt: 10, total: 15 }),
+    ];
+
+    const result$ = of(...events).pipe(emitTokenCountEstimateIfMissing({ request: stubRequest() }));
+    const output = await lastValueFrom(result$.pipe(toArray()));
+
+    expect(output).toEqual(events);
+
+    expect(manuallyCountPromptTokensMock).not.toHaveBeenCalled();
+    expect(manuallyCountCompletionTokensMock).not.toHaveBeenCalled();
+  });
+
+  it('emits a tokenCount event if the source completes without emitting one', async () => {
+    manuallyCountPromptTokensMock.mockReturnValue(5);
+    manuallyCountCompletionTokensMock.mockReturnValue(10);
+
+    const events = [chunkEvent('chunk-1'), chunkEvent('chunk-2'), chunkEvent('chunk-3')];
+
+    const result$ = of(...events).pipe(emitTokenCountEstimateIfMissing({ request: stubRequest() }));
+    const output = await lastValueFrom(result$.pipe(toArray()));
+
+    expect(manuallyCountPromptTokensMock).toHaveBeenCalledTimes(1);
+    expect(manuallyCountCompletionTokensMock).toHaveBeenCalledTimes(1);
+
+    expect(output).toEqual([...events, tokensEvent({ prompt: 5, completion: 10, total: 15 })]);
+  });
+});
@@ -0,0 +1,74 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { OperatorFunction, Observable } from 'rxjs';
+import {
+  ChatCompletionChunkEvent,
+  ChatCompletionEventType,
+  ChatCompletionTokenCountEvent,
+  isChatCompletionTokenCountEvent,
+  isChatCompletionChunkEvent,
+} from '@kbn/inference-common';
+import type { OpenAIRequest } from './types';
+import { manuallyCountPromptTokens, manuallyCountCompletionTokens } from './manually_count_tokens';
+
+/**
+ * Operator mirroring the source and then emitting a tokenCount event when the source completes,
+ * if and only if the source did not emit a tokenCount event itself.
+ *
+ * This is used to manually count tokens and emit the associated event for
+ * providers that don't support sending token counts for the stream API.
+ *
+ * @param request the OpenAI request that was sent to the connector.
+ */
+export function emitTokenCountEstimateIfMissing<
+  T extends ChatCompletionChunkEvent | ChatCompletionTokenCountEvent
+>({ request }: { request: OpenAIRequest }): OperatorFunction<T, T | ChatCompletionTokenCountEvent> {
+  return (source$) => {
+    let tokenCountEmitted = false;
+    const chunks: ChatCompletionChunkEvent[] = [];
+
+    return new Observable<T | ChatCompletionTokenCountEvent>((subscriber) => {
+      return source$.subscribe({
+        next: (value) => {
+          if (isChatCompletionTokenCountEvent(value)) {
+            tokenCountEmitted = true;
+          } else if (isChatCompletionChunkEvent(value)) {
+            chunks.push(value);
+          }
+          subscriber.next(value);
+        },
+        error: (err) => {
+          subscriber.error(err);
+        },
+        complete: () => {
+          if (!tokenCountEmitted) {
+            subscriber.next(manuallyCountTokens(request, chunks));
+          }
+          subscriber.complete();
+        },
+      });
+    });
+  };
+}
+
+export function manuallyCountTokens(
+  request: OpenAIRequest,
+  chunks: ChatCompletionChunkEvent[]
+): ChatCompletionTokenCountEvent {
+  const promptTokens = manuallyCountPromptTokens(request);
+  const completionTokens = manuallyCountCompletionTokens(chunks);
+
+  return {
+    type: ChatCompletionEventType.ChatCompletionTokenCount,
+    tokens: {
+      prompt: promptTokens,
+      completion: completionTokens,
+      total: promptTokens + completionTokens,
+    },
+  };
+}
@@ -8,3 +8,4 @@
 export { openAIAdapter } from './openai_adapter';
 export { toolChoiceToOpenAI, messagesToOpenAI, toolsToOpenAI } from './to_openai';
 export { processOpenAIStream } from './process_openai_stream';
+export { emitTokenCountEstimateIfMissing } from './emit_token_count_if_missing';
@@ -0,0 +1,67 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { chunkEvent } from '../../../test_utils';
+import { manuallyCountPromptTokens, manuallyCountCompletionTokens } from './manually_count_tokens';
+
+describe('manuallyCountPromptTokens', () => {
+  const reference = manuallyCountPromptTokens({
+    messages: [{ role: 'user', content: 'message' }],
+  });
+
+  it('counts token from the message content', () => {
+    const count = manuallyCountPromptTokens({
+      messages: [
+        { role: 'user', content: 'question 1' },
+        { role: 'assistant', content: 'answer 1' },
+        { role: 'user', content: 'question 2' },
+      ],
+    });
+
+    expect(count).toBeGreaterThan(reference);
+  });
+
+  it('counts token from tools', () => {
+    const count = manuallyCountPromptTokens({
+      messages: [{ role: 'user', content: 'message' }],
+      tools: [{ type: 'function', function: { name: 'my-function', description: 'description' } }],
+    });
+
+    expect(count).toBeGreaterThan(reference);
+  });
+});
+
+describe('manuallyCountCompletionTokens', () => {
+  const reference = manuallyCountCompletionTokens([chunkEvent('chunk-1')]);
+
+  it('counts tokens from the content chunks', () => {
+    const count = manuallyCountCompletionTokens([
+      chunkEvent('chunk-1'),
+      chunkEvent('chunk-2'),
+      chunkEvent('chunk-2'),
+    ]);
+
+    expect(count).toBeGreaterThan(reference);
+  });
+
+  it('counts tokens from chunks with tool calls', () => {
+    const count = manuallyCountCompletionTokens([
+      chunkEvent('chunk-1', [
+        {
+          toolCallId: 'tool-call-id',
+          index: 0,
+          function: {
+            name: 'function',
+            arguments: '{}',
+          },
+        },
+      ]),
+    ]);
+
+    expect(count).toBeGreaterThan(reference);
+  });
+});
@@ -0,0 +1,62 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { encode } from 'gpt-tokenizer';
+import { ChatCompletionChunkEvent } from '@kbn/inference-common';
+import type { OpenAIRequest } from './types';
+import { mergeChunks } from '../../utils';
+
+export const manuallyCountPromptTokens = (request: OpenAIRequest) => {
+  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+  const tokensFromMessages = encode(
+    request.messages
+      .map(
+        (msg) =>
+          `<|start|>${msg.role}\n${msg.content}\n${
+            'name' in msg
+              ? msg.name
+              : 'function_call' in msg && msg.function_call
+              ? msg.function_call.name + '\n' + msg.function_call.arguments
+              : ''
+          }<|end|>`
+      )
+      .join('\n')
+  ).length;
+
+  // this is an approximation. OpenAI cuts off a function schema
+  // at a certain level of nesting, so their token count might
+  // be lower than what we are calculating here.
+  const tokensFromFunctions = request.tools
+    ? encode(
+        request.tools
+          ?.map(({ function: fn }) => {
+            return `${fn.name}:${fn.description}:${JSON.stringify(fn.parameters)}`;
+          })
+          .join('\n')
+      ).length
+    : 0;
+
+  return tokensFromMessages + tokensFromFunctions;
+};
+
+export const manuallyCountCompletionTokens = (chunks: ChatCompletionChunkEvent[]) => {
+  const message = mergeChunks(chunks);
+
+  const tokenFromContent = encode(message.content).length;
+
+  const tokenFromToolCalls = message.tool_calls?.length
+    ? encode(
+        message.tool_calls
+          .map((toolCall) => {
+            return JSON.stringify(toolCall);
+          })
+          .join('\n')
+      ).length
+    : 0;
+
+  return tokenFromContent + tokenFromToolCalls;
+};
@@ -9,10 +9,14 @@ import OpenAI from 'openai';
 import { v4 } from 'uuid';
 import { PassThrough } from 'stream';
 import { pick } from 'lodash';
-import { lastValueFrom, Subject, toArray } from 'rxjs';
+import { lastValueFrom, Subject, toArray, filter } from 'rxjs';
 import type { Logger } from '@kbn/logging';
 import { loggerMock } from '@kbn/logging-mocks';
-import { ChatCompletionEventType, MessageRole } from '@kbn/inference-common';
+import {
+  ChatCompletionEventType,
+  isChatCompletionChunkEvent,
+  MessageRole,
+} from '@kbn/inference-common';
 import { observableIntoEventSourceStream } from '../../../util/observable_into_event_source_stream';
 import { InferenceExecutor } from '../../utils/inference_executor';
 import { openAIAdapter } from './openai_adapter';
@@ -448,7 +452,9 @@ describe('openAIAdapter', () => {
 
      source$.complete();
 
-      const allChunks = await lastValueFrom(response$.pipe(toArray()));
+      const allChunks = await lastValueFrom(
+        response$.pipe(filter(isChatCompletionChunkEvent), toArray())
+      );
 
      expect(allChunks).toEqual([
        {
@@ -502,7 +508,9 @@ describe('openAIAdapter', () => {
 
      source$.complete();
 
-      const allChunks = await lastValueFrom(response$.pipe(toArray()));
+      const allChunks = await lastValueFrom(
+        response$.pipe(filter(isChatCompletionChunkEvent), toArray())
+      );
 
      expect(allChunks).toEqual([
        {
@@ -576,5 +584,45 @@ describe('openAIAdapter', () => {
        },
      ]);
    });
+
+    it('emits token count event when not provided by the response', async () => {
+      const response$ = openAIAdapter.chatComplete({
+        ...defaultArgs,
+        messages: [
+          {
+            role: MessageRole.User,
+            content: 'Hello',
+          },
+        ],
+      });
+
+      source$.next(
+        createOpenAIChunk({
+          delta: {
+            content: 'chunk',
+          },
+        })
+      );
+
+      source$.complete();
+
+      const allChunks = await lastValueFrom(response$.pipe(toArray()));
+
+      expect(allChunks).toEqual([
+        {
+          type: ChatCompletionEventType.ChatCompletionChunk,
+          content: 'chunk',
+          tool_calls: [],
+        },
+        {
+          type: ChatCompletionEventType.ChatCompletionTokenCount,
+          tokens: {
+            completion: expect.any(Number),
+            prompt: expect.any(Number),
+            total: expect.any(Number),
+          },
+        },
+      ]);
+    });
   });
 });
@@ -5,7 +5,6 @@
  * 2.0.
  */
 
-import type OpenAI from 'openai';
 import { from, identity, switchMap, throwError } from 'rxjs';
 import { isReadable, Readable } from 'stream';
 import { createInferenceInternalError } from '@kbn/inference-common';
@@ -15,8 +14,10 @@ import {
   parseInlineFunctionCalls,
   wrapWithSimulatedFunctionCalling,
 } from '../../simulated_function_calling';
+import type { OpenAIRequest } from './types';
 import { messagesToOpenAI, toolsToOpenAI, toolChoiceToOpenAI } from './to_openai';
 import { processOpenAIStream } from './process_openai_stream';
+import { emitTokenCountEstimateIfMissing } from './emit_token_count_if_missing';
 
 export const openAIAdapter: InferenceConnectorAdapter = {
   chatComplete: ({
@@ -33,7 +34,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
   }) => {
     const simulatedFunctionCalling = functionCalling === 'simulated';
 
-    let request: Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
+    let request: OpenAIRequest;
     if (simulatedFunctionCalling) {
       const wrapped = wrapWithSimulatedFunctionCalling({
         system,
@@ -84,6 +85,7 @@ export const openAIAdapter: InferenceConnectorAdapter = {
        );
      }),
      processOpenAIStream(),
+      emitTokenCountEstimateIfMissing({ request }),
      simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
    );
  },
@@ -0,0 +1,10 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import type OpenAI from 'openai';
+
+export type OpenAIRequest = Omit<OpenAI.ChatCompletionCreateParams, 'model'> & { model?: string };
@@ -11,12 +11,12 @@ import {
   ChatCompletionMessageEvent,
   ChatCompletionTokenCountEvent,
   ToolOptions,
-  UnvalidatedToolCall,
   withoutTokenCountEvents,
 } from '@kbn/inference-common';
 import type { Logger } from '@kbn/logging';
 import { OperatorFunction, map, merge, share, toArray } from 'rxjs';
 import { validateToolCalls } from '../../util/validate_tool_calls';
+import { mergeChunks } from './merge_chunks';
 
 export function chunksIntoMessage<TToolOptions extends ToolOptions>({
   logger,
@@ -39,36 +39,7 @@ export function chunksIntoMessage<TToolOptions extends ToolOptions>({
       withoutTokenCountEvents(),
       toArray(),
       map((chunks): ChatCompletionMessageEvent<TToolOptions> => {
-        const concatenatedChunk = chunks.reduce(
-          (prev, chunk) => {
-            prev.content += chunk.content ?? '';
-
-            chunk.tool_calls?.forEach((toolCall) => {
-              let prevToolCall = prev.tool_calls[toolCall.index];
-              if (!prevToolCall) {
-                prev.tool_calls[toolCall.index] = {
-                  function: {
-                    name: '',
-                    arguments: '',
-                  },
-                  toolCallId: '',
-                };
-
-                prevToolCall = prev.tool_calls[toolCall.index];
-              }
-
-              prevToolCall.function.name += toolCall.function.name;
-              prevToolCall.function.arguments += toolCall.function.arguments;
-              prevToolCall.toolCallId += toolCall.toolCallId;
-            });
-
-            return prev;
-          },
-          { content: '', tool_calls: [] as UnvalidatedToolCall[] }
-        );
-
-        // some models (Claude not to name it) can have their toolCall index not start at 0, so we remove the null elements
-        concatenatedChunk.tool_calls = concatenatedChunk.tool_calls.filter((call) => !!call);
+        const concatenatedChunk = mergeChunks(chunks);
 
         logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`);
 
@@ -14,3 +14,4 @@ export {
 export { chunksIntoMessage } from './chunks_into_message';
 export { streamToResponse } from './stream_to_response';
 export { handleCancellation } from './handle_cancellation';
+export { mergeChunks } from './merge_chunks';
@@ -0,0 +1,223 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { ChatCompletionEventType } from '@kbn/inference-common';
+import { mergeChunks } from './merge_chunks';
+
+describe('mergeChunks', () => {
+  it('concatenates content chunks into a single message', async () => {
+    const message = mergeChunks([
+      {
+        content: 'Hey',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [],
+      },
+      {
+        content: ' how is it',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [],
+      },
+      {
+        content: ' going',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [],
+      },
+    ]);
+
+    expect(message).toEqual({
+      content: 'Hey how is it going',
+      tool_calls: [],
+    });
+  });
+
+  it('concatenates tool calls', async () => {
+    const message = mergeChunks([
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: 'myFunction',
+              arguments: '',
+            },
+            index: 0,
+            toolCallId: '0',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: '',
+              arguments: '{ ',
+            },
+            index: 0,
+            toolCallId: '0',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: '',
+              arguments: '"foo": "bar" }',
+            },
+            index: 0,
+            toolCallId: '1',
+          },
+        ],
+      },
+    ]);
+
+    expect(message).toEqual({
+      content: '',
+      tool_calls: [
+        {
+          function: {
+            name: 'myFunction',
+            arguments: '{ "foo": "bar" }',
+          },
+          toolCallId: '001',
+        },
+      ],
+    });
+  });
+
+  it('concatenates tool calls even when the index does not start at 0', async () => {
+    const message = mergeChunks([
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: 'myFunction',
+              arguments: '',
+            },
+            index: 1,
+            toolCallId: '0',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: '',
+              arguments: '{ ',
+            },
+            index: 1,
+            toolCallId: '0',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: '',
+              arguments: '"foo": "bar" }',
+            },
+            index: 1,
+            toolCallId: '1',
+          },
+        ],
+      },
+    ]);
+
+    expect(message).toEqual({
+      content: '',
+      tool_calls: [
+        {
+          function: {
+            name: 'myFunction',
+            arguments: '{ "foo": "bar" }',
+          },
+          toolCallId: '001',
+        },
+      ],
+    });
+  });
+
+  it('concatenates multiple tool calls into a single message', async () => {
+    const message = mergeChunks([
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: 'myFunction',
+              arguments: '',
+            },
+            index: 0,
+            toolCallId: '001',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: '',
+              arguments: '{"foo": "bar"}',
+            },
+            index: 0,
+            toolCallId: '',
+          },
+        ],
+      },
+      {
+        content: '',
+        type: ChatCompletionEventType.ChatCompletionChunk,
+        tool_calls: [
+          {
+            function: {
+              name: 'myFunction',
+              arguments: '{ "foo": "baz" }',
+            },
+            index: 1,
+            toolCallId: '002',
+          },
+        ],
+      },
+    ]);
+
+    expect(message).toEqual({
+      content: '',
+      tool_calls: [
+        {
+          function: {
+            name: 'myFunction',
+            arguments: '{"foo": "bar"}',
+          },
+          toolCallId: '001',
+        },
+        {
+          function: {
+            name: 'myFunction',
+            arguments: '{ "foo": "baz" }',
+          },
+          toolCallId: '002',
+        },
+      ],
+    });
+  });
+});
@@ -0,0 +1,51 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+import { ChatCompletionChunkEvent, UnvalidatedToolCall } from '@kbn/inference-common';
+
+interface UnvalidatedMessage {
+  content: string;
+  tool_calls: UnvalidatedToolCall[];
+}
+
+/**
+ * Merges chunks into a message, concatenating the content and tool calls.
+ */
+export const mergeChunks = (chunks: ChatCompletionChunkEvent[]): UnvalidatedMessage => {
+  const message = chunks.reduce<UnvalidatedMessage>(
+    (prev, chunk) => {
+      prev.content += chunk.content ?? '';
+
+      chunk.tool_calls?.forEach((toolCall) => {
+        let prevToolCall = prev.tool_calls[toolCall.index];
+        if (!prevToolCall) {
+          prev.tool_calls[toolCall.index] = {
+            function: {
+              name: '',
+              arguments: '',
+            },
+            toolCallId: '',
+          };
+
+          prevToolCall = prev.tool_calls[toolCall.index];
+        }
+
+        prevToolCall.function.name += toolCall.function.name;
+        prevToolCall.function.arguments += toolCall.function.arguments;
+        prevToolCall.toolCallId += toolCall.toolCallId;
+      });
+
+      return prev;
+    },
+    { content: '', tool_calls: [] }
+  );
+
+  // some models (Claude not to name it) can have their toolCall index not start at 0, so we remove the null elements
+  message.tool_calls = message.tool_calls.filter((call) => !!call);
+
+  return message;
+};
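A side note on the final `filter` in `mergeChunks`: indexing `tool_calls` by the chunk's `index` leaves holes when a provider starts its indices above 0 (the comment calls out Claude), because JavaScript arrays written past their end contain empty slots. An illustrative snippet, not from the commit:

```ts
// Assigning past the end of an array creates a hole, not a null element.
const toolCalls: Array<{ toolCallId: string } | undefined> = [];
toolCalls[1] = { toolCallId: 'a' };

console.log(toolCalls.length); // 2 (index 0 is an empty slot)
console.log(toolCalls[0]); // undefined

// The truthiness filter drops the holes, which is what mergeChunks relies on:
const compacted = toolCalls.filter((call) => !!call);
console.log(compacted.length); // 1
```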
@@ -12,12 +12,16 @@ import {
   ChatCompletionMessageEvent,
   ChatCompletionTokenCount,
   ToolCall,
+  ChatCompletionChunkToolCall,
 } from '@kbn/inference-common';
 
-export const chunkEvent = (content: string = 'chunk'): ChatCompletionChunkEvent => ({
+export const chunkEvent = (
+  content: string = 'chunk',
+  toolCalls: ChatCompletionChunkToolCall[] = []
+): ChatCompletionChunkEvent => ({
   type: ChatCompletionEventType.ChatCompletionChunk,
   content,
-  tool_calls: [],
+  tool_calls: toolCalls,
 });
 
 export const messageEvent = (