Mirror of https://github.com/elastic/kibana.git, synced 2025-04-24 01:38:56 -04:00
# Backport

This will backport the following commits from `main` to `8.x`:

- [[inference] add support for openAI native stream token count (#200745)](https://github.com/elastic/kibana/pull/200745)

### Questions?

Please refer to the [Backport tool documentation](https://github.com/sqren/backport)

## Summary

Fix https://github.com/elastic/kibana/issues/192962

Add support for native openAI token counts for streaming APIs.

This is done by adding the `stream_options: {"include_usage": true}` parameter when `stream: true` is used ([doc](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stream_options)), and then reading the `usage` entry of the last emitted chunk.

**Note**: this was done only for the `OpenAI` and `AzureAI` [providers](83a701e837/x-pack/plugins/stack_connectors/common/openai/constants.ts (L27-L31)), and **not** for the `Other` provider. The reasoning is that not all openAI "compatible" providers fully support all options, so adding the parameter could cause some models behind an openAI adapter to reject the requests. This is also why the [getTokenCountFromOpenAIStream](8bffd61805/x-pack/plugins/actions/server/lib/get_token_count_from_openai_stream.ts (L15)) function keeps its manual token counting as a fallback, since it has to work for all providers.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Co-authored-by: Pierre Gayvallet <pierre.gayvallet@elastic.co>
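For illustration, here is a minimal sketch (not the Kibana code) of the mechanism described above, assuming a recent version of the official `openai` npm client; the model name and client setup are illustrative assumptions. With `stream_options: { include_usage: true }`, the final streamed chunk has an empty `choices` array and carries the native `usage` counts.

```ts
// Sketch only: read the native usage entry from a streamed chat completion.
import OpenAI from 'openai';

async function streamWithUsage(client: OpenAI): Promise<OpenAI.CompletionUsage | undefined> {
  const stream = await client.chat.completions.create({
    model: 'gpt-4o',
    messages: [{ role: 'user', content: 'Hello' }],
    stream: true,
    stream_options: { include_usage: true },
  });

  let usage: OpenAI.CompletionUsage | undefined;
  for await (const chunk of stream) {
    if (chunk.usage) {
      // Only the last emitted chunk carries `usage`; its `choices` array is empty.
      usage = chunk.usage;
    }
  }
  return usage; // { prompt_tokens, completion_tokens, total_tokens }
}
```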
Parent: 43d4730ebd
Commit: 63f1de7e1b

9 changed files with 401 additions and 144 deletions
@@ -61,6 +61,16 @@ describe('getTokenCountFromOpenAIStream', () => {
      ],
    };

    const usageChunk = {
      object: 'chat.completion.chunk',
      choices: [],
      usage: {
        prompt_tokens: 50,
        completion_tokens: 100,
        total_tokens: 150,
      },
    };

    const PROMPT_TOKEN_COUNT = 36;
    const COMPLETION_TOKEN_COUNT = 5;

@@ -70,55 +80,79 @@ describe('getTokenCountFromOpenAIStream', () => {
    });

    describe('when a stream completes', () => {
      beforeEach(async () => {
        stream.write('data: [DONE]');
        stream.complete();
      });
      describe('with usage chunk', () => {
        it('returns the counts from the usage chunk', async () => {
          stream = createStreamMock();
          stream.write(`data: ${JSON.stringify(chunk)}`);
          stream.write(`data: ${JSON.stringify(usageChunk)}`);
          stream.write('data: [DONE]');
          stream.complete();

      describe('without function tokens', () => {
        beforeEach(async () => {
          tokens = await getTokenCountFromOpenAIStream({
            responseStream: stream.transform,
            logger,
            body: JSON.stringify(body),
          });
        });

        it('counts the prompt tokens', () => {
          expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
          expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
          expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
          expect(tokens).toEqual({
            prompt: usageChunk.usage.prompt_tokens,
            completion: usageChunk.usage.completion_tokens,
            total: usageChunk.usage.total_tokens,
          });
        });
      });

      describe('with function tokens', () => {
      describe('without usage chunk', () => {
        beforeEach(async () => {
          tokens = await getTokenCountFromOpenAIStream({
            responseStream: stream.transform,
            logger,
            body: JSON.stringify({
              ...body,
              functions: [
                {
                  name: 'my_function',
                  description: 'My function description',
                  parameters: {
                    type: 'object',
                    properties: {
                      my_property: {
                        type: 'boolean',
                        description: 'My function property',
                      },
                    },
                  },
                },
              ],
            }),
          stream.write('data: [DONE]');
          stream.complete();
        });

        describe('without function tokens', () => {
          beforeEach(async () => {
            tokens = await getTokenCountFromOpenAIStream({
              responseStream: stream.transform,
              logger,
              body: JSON.stringify(body),
            });
          });

          it('counts the prompt tokens', () => {
            expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
            expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
            expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
          });
        });

        it('counts the function tokens', () => {
          expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
        describe('with function tokens', () => {
          beforeEach(async () => {
            tokens = await getTokenCountFromOpenAIStream({
              responseStream: stream.transform,
              logger,
              body: JSON.stringify({
                ...body,
                functions: [
                  {
                    name: 'my_function',
                    description: 'My function description',
                    parameters: {
                      type: 'object',
                      properties: {
                        my_property: {
                          type: 'boolean',
                          description: 'My function property',
                        },
                      },
                    },
                  },
                ],
              }),
            });
          });

          it('counts the function tokens', () => {
            expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
          });
        });
      });
    });

@@ -25,9 +25,91 @@ export async function getTokenCountFromOpenAIStream({
  prompt: number;
  completion: number;
}> {
  const chatCompletionRequest = JSON.parse(
    body
  ) as OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming;
  let responseBody = '';

  responseStream.on('data', (chunk: string) => {
    responseBody += chunk.toString();
  });

  try {
    await finished(responseStream);
  } catch (e) {
    logger.error('An error occurred while calculating streaming response tokens');
  }

  let completionUsage: OpenAI.CompletionUsage | undefined;

  const response: ParsedResponse = responseBody
    .split('\n')
    .filter((line) => {
      return line.startsWith('data: ') && !line.endsWith('[DONE]');
    })
    .map((line) => {
      return JSON.parse(line.replace('data: ', ''));
    })
    .filter((line): line is OpenAI.ChatCompletionChunk => {
      return 'object' in line && line.object === 'chat.completion.chunk';
    })
    .reduce(
      (prev, line) => {
        if (line.usage) {
          completionUsage = line.usage;
        }
        if (line.choices?.length) {
          const msg = line.choices[0].delta!;
          prev.content += msg.content || '';
          prev.function_call.name += msg.function_call?.name || '';
          prev.function_call.arguments += msg.function_call?.arguments || '';
        }
        return prev;
      },
      { content: '', function_call: { name: '', arguments: '' } }
    );

  // not all openAI compatible providers emit completion chunk, so we still have to support
  // manually counting the tokens
  if (completionUsage) {
    return {
      prompt: completionUsage.prompt_tokens,
      completion: completionUsage.completion_tokens,
      total: completionUsage.total_tokens,
    };
  } else {
    const promptTokens = manuallyCountPromptTokens(body);
    const completionTokens = manuallyCountCompletionTokens(response);
    return {
      prompt: promptTokens,
      completion: completionTokens,
      total: promptTokens + completionTokens,
    };
  }
}

interface ParsedResponse {
  content: string;
  function_call: {
    name: string;
    arguments: string;
  };
}

const manuallyCountCompletionTokens = (response: ParsedResponse) => {
  return encode(
    JSON.stringify(
      omitBy(
        {
          content: response.content || undefined,
          function_call: response.function_call.name ? response.function_call : undefined,
        },
        isEmpty
      )
    )
  ).length;
};

const manuallyCountPromptTokens = (requestBody: string) => {
  const chatCompletionRequest: OpenAI.ChatCompletionCreateParams.ChatCompletionCreateParamsStreaming =
    JSON.parse(requestBody);

  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  const tokensFromMessages = encode(

@@ -60,67 +142,5 @@ export async function getTokenCountFromOpenAIStream({
        ).length
      : 0;

  const promptTokens = tokensFromMessages + tokensFromFunctions;

  let responseBody: string = '';

  responseStream.on('data', (chunk: string) => {
    responseBody += chunk.toString();
  });

  try {
    await finished(responseStream);
  } catch (e) {
    logger.error('An error occurred while calculating streaming response tokens');
  }

  const response = responseBody
    .split('\n')
    .filter((line) => {
      return line.startsWith('data: ') && !line.endsWith('[DONE]');
    })
    .map((line) => {
      return JSON.parse(line.replace('data: ', ''));
    })
    .filter(
      (
        line
      ): line is {
        choices: Array<{
          delta: { content?: string; function_call?: { name?: string; arguments: string } };
        }>;
      } => {
        return (
          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
        );
      }
    )
    .reduce(
      (prev, line) => {
        const msg = line.choices[0].delta!;
        prev.content += msg.content || '';
        prev.function_call.name += msg.function_call?.name || '';
        prev.function_call.arguments += msg.function_call?.arguments || '';
        return prev;
      },
      { content: '', function_call: { name: '', arguments: '' } }
    );

  const completionTokens = encode(
    JSON.stringify(
      omitBy(
        {
          content: response.content || undefined,
          function_call: response.function_call.name ? response.function_call : undefined,
        },
        isEmpty
      )
    )
  ).length;

  return {
    prompt: promptTokens,
    completion: completionTokens,
    total: promptTokens + completionTokens,
  };
}
  return tokensFromMessages + tokensFromFunctions;
};

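The hunks above change `getTokenCountFromOpenAIStream` to prefer the provider-reported `usage` entry and keep manual counting only as a fallback, since not every openAI-compatible provider emits the usage chunk. A minimal standalone sketch of that decision (not the Kibana implementation; it assumes SSE lines already captured into a string and the `gpt-tokenizer` package for the fallback count):

```ts
// Sketch only: prefer native usage from the stream, otherwise count tokens locally.
import { encode } from 'gpt-tokenizer';

interface StreamChunk {
  object: string;
  usage?: { prompt_tokens: number; completion_tokens: number; total_tokens: number };
  choices?: Array<{ delta?: { content?: string } }>;
}

export function countTokensFromSse(responseBody: string, promptText: string) {
  const chunks: StreamChunk[] = responseBody
    .split('\n')
    .filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
    .map((line) => JSON.parse(line.replace('data: ', '')))
    .filter((chunk) => chunk.object === 'chat.completion.chunk');

  // Native counts: present on the final chunk when stream_options.include_usage was requested.
  const usage = chunks.find((chunk) => chunk.usage)?.usage;
  if (usage) {
    return { prompt: usage.prompt_tokens, completion: usage.completion_tokens, total: usage.total_tokens };
  }

  // Fallback: concatenate the streamed content and count tokens manually.
  const content = chunks.map((chunk) => chunk.choices?.[0]?.delta?.content ?? '').join('');
  const prompt = encode(promptText).length;
  const completion = encode(content).length;
  return { prompt, completion, total: prompt + completion };
}
```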
@@ -21,17 +21,19 @@ function createOpenAIChunk({
  delta,
  usage,
}: {
  delta: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
  delta?: OpenAI.ChatCompletionChunk['choices'][number]['delta'];
  usage?: OpenAI.ChatCompletionChunk['usage'];
}): OpenAI.ChatCompletionChunk {
  return {
    choices: [
      {
        finish_reason: null,
        index: 0,
        delta,
      },
    ],
    choices: delta
      ? [
          {
            finish_reason: null,
            index: 0,
            delta,
          },
        ]
      : [],
    created: new Date().getTime(),
    id: v4(),
    model: 'gpt-4o',

@@ -313,7 +315,7 @@ describe('openAIAdapter', () => {
      ]);
    });

    it('emits token events', async () => {
    it('emits chunk events with tool calls', async () => {
      const response$ = openAIAdapter.chatComplete({
        ...defaultArgs,
        messages: [

@@ -375,5 +377,55 @@ describe('openAIAdapter', () => {
        },
      ]);
    });

    it('emits token count events', async () => {
      const response$ = openAIAdapter.chatComplete({
        ...defaultArgs,
        messages: [
          {
            role: MessageRole.User,
            content: 'Hello',
          },
        ],
      });

      source$.next(
        createOpenAIChunk({
          delta: {
            content: 'chunk',
          },
        })
      );

      source$.next(
        createOpenAIChunk({
          usage: {
            prompt_tokens: 50,
            completion_tokens: 100,
            total_tokens: 150,
          },
        })
      );

      source$.complete();

      const allChunks = await lastValueFrom(response$.pipe(toArray()));

      expect(allChunks).toEqual([
        {
          type: ChatCompletionEventType.ChatCompletionChunk,
          content: 'chunk',
          tool_calls: [],
        },
        {
          type: ChatCompletionEventType.ChatCompletionTokenCount,
          tokens: {
            prompt: 50,
            completion: 100,
            total: 150,
          },
        },
      ]);
    });
  });
});

@@ -5,7 +5,7 @@
 * 2.0.
 */

import OpenAI from 'openai';
import type OpenAI from 'openai';
import type {
  ChatCompletionAssistantMessageParam,
  ChatCompletionMessageParam,

@@ -13,22 +13,33 @@ import type {
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam,
} from 'openai/resources';
import { filter, from, map, switchMap, tap, throwError, identity } from 'rxjs';
import { Readable, isReadable } from 'stream';
import {
  filter,
  from,
  identity,
  map,
  mergeMap,
  Observable,
  switchMap,
  tap,
  throwError,
} from 'rxjs';
import { isReadable, Readable } from 'stream';
import {
  ChatCompletionChunkEvent,
  ChatCompletionEventType,
  ChatCompletionTokenCountEvent,
  createInferenceInternalError,
  Message,
  MessageRole,
  ToolOptions,
  createInferenceInternalError,
} from '@kbn/inference-common';
import { createTokenLimitReachedError } from '../../errors';
import { eventSourceStreamIntoObservable } from '../../../util/event_source_stream_into_observable';
import type { InferenceConnectorAdapter } from '../../types';
import {
  wrapWithSimulatedFunctionCalling,
  parseInlineFunctionCalls,
  wrapWithSimulatedFunctionCalling,
} from '../../simulated_function_calling';

export const openAIAdapter: InferenceConnectorAdapter = {

@@ -92,34 +103,57 @@ export const openAIAdapter: InferenceConnectorAdapter = {
          throw createTokenLimitReachedError();
        }
      }),
      filter(
        (line): line is OpenAI.ChatCompletionChunk =>
          'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
      ),
      map((chunk): ChatCompletionChunkEvent => {
        const delta = chunk.choices[0].delta;

        return {
          type: ChatCompletionEventType.ChatCompletionChunk,
          content: delta.content ?? '',
          tool_calls:
            delta.tool_calls?.map((toolCall) => {
              return {
                function: {
                  name: toolCall.function?.name ?? '',
                  arguments: toolCall.function?.arguments ?? '',
                },
                toolCallId: toolCall.id ?? '',
                index: toolCall.index,
              };
            }) ?? [],
        };
      filter((line): line is OpenAI.ChatCompletionChunk => {
        return 'object' in line && line.object === 'chat.completion.chunk';
      }),
      mergeMap((chunk): Observable<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> => {
        const events: Array<ChatCompletionChunkEvent | ChatCompletionTokenCountEvent> = [];
        if (chunk.usage) {
          events.push(tokenCountFromOpenAI(chunk.usage));
        }
        if (chunk.choices?.length) {
          events.push(chunkFromOpenAI(chunk));
        }
        return from(events);
      }),
      simulatedFunctionCalling ? parseInlineFunctionCalls({ logger }) : identity
    );
  },
};

function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompletionChunkEvent {
  const delta = chunk.choices[0].delta;

  return {
    type: ChatCompletionEventType.ChatCompletionChunk,
    content: delta.content ?? '',
    tool_calls:
      delta.tool_calls?.map((toolCall) => {
        return {
          function: {
            name: toolCall.function?.name ?? '',
            arguments: toolCall.function?.arguments ?? '',
          },
          toolCallId: toolCall.id ?? '',
          index: toolCall.index,
        };
      }) ?? [],
  };
}

function tokenCountFromOpenAI(
  completionUsage: OpenAI.CompletionUsage
): ChatCompletionTokenCountEvent {
  return {
    type: ChatCompletionEventType.ChatCompletionTokenCount,
    tokens: {
      completion: completionUsage.completion_tokens,
      prompt: completionUsage.prompt_tokens,
      total: completionUsage.total_tokens,
    },
  };
}

function toolsToOpenAI(tools: ToolOptions['tools']): OpenAI.ChatCompletionCreateParams['tools'] {
  return tools
    ? Object.entries(tools).map(([toolName, { description, schema }]) => {

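In the adapter hunk above, the single `map` to chunk events is replaced by a `mergeMap`, so one incoming OpenAI chunk can fan out into zero, one, or two events (a content chunk and/or a token-count event). A minimal RxJS sketch of that fan-out pattern, using simplified stand-in types rather than the Kibana event shapes:

```ts
// Sketch only: fan a stream of OpenAI-style chunks out into content and token-count events.
import { from, mergeMap, Observable } from 'rxjs';

interface Usage {
  prompt_tokens: number;
  completion_tokens: number;
  total_tokens: number;
}
interface Chunk {
  choices: Array<{ delta: { content?: string } }>;
  usage?: Usage;
}

type AdapterEvent =
  | { type: 'chunk'; content: string }
  | { type: 'tokenCount'; tokens: { prompt: number; completion: number; total: number } };

function toEvents(chunk$: Observable<Chunk>): Observable<AdapterEvent> {
  return chunk$.pipe(
    mergeMap((chunk) => {
      const events: AdapterEvent[] = [];
      if (chunk.usage) {
        events.push({
          type: 'tokenCount',
          tokens: {
            prompt: chunk.usage.prompt_tokens,
            completion: chunk.usage.completion_tokens,
            total: chunk.usage.total_tokens,
          },
        });
      }
      if (chunk.choices.length) {
        events.push({ type: 'chunk', content: chunk.choices[0].delta.content ?? '' });
      }
      return from(events);
    })
  );
}

// Usage: the trailing usage-only chunk (empty `choices`) becomes a single token-count event.
const sample: Chunk[] = [
  { choices: [{ delta: { content: 'Hello' } }] },
  { choices: [], usage: { prompt_tokens: 50, completion_tokens: 100, total_tokens: 150 } },
];
toEvents(from(sample)).subscribe((event) => console.log(event));
```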
@@ -101,9 +101,50 @@ describe('Azure Open AI Utils', () => {
      };
      [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), true);
        expect(sanitizedBodyString).toEqual(
          `{\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}`
        );
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          stream: true,
          stream_options: {
            include_usage: true,
          },
        });
      });
    });
    it('sets stream_options when stream is true', () => {
      const body = {
        messages: [
          {
            role: 'user',
            content: 'This is a test',
          },
        ],
      };
      [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), true);
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          stream: true,
          stream_options: {
            include_usage: true,
          },
        });
      });
    });
    it('does not sets stream_options when stream is false', () => {
      const body = {
        messages: [
          {
            role: 'user',
            content: 'This is a test',
          },
        ],
      };
      [chatUrl, completionUrl, completionExtensionsUrl].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(url, JSON.stringify(body), false);
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          stream: false,
        });
      });
    });
    it('overrides stream parameter if defined in body', () => {

@@ -48,6 +48,11 @@ export const getRequestWithStreamOption = (url: string, body: string, stream: bo
  const jsonBody = JSON.parse(body);
  if (jsonBody) {
    jsonBody.stream = stream;
    if (stream) {
      jsonBody.stream_options = {
        include_usage: true,
      };
    }
  }

  return JSON.stringify(jsonBody);

@@ -118,6 +118,31 @@ describe('Open AI Utils', () => {
        ],
      };

      [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(
          url,
          JSON.stringify(body),
          false,
          DEFAULT_OPENAI_MODEL
        );
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          model: 'gpt-4',
          stream: false,
        });
      });
    });
    it('sets stream_options when stream is true', () => {
      const body = {
        model: 'gpt-4',
        messages: [
          {
            role: 'user',
            content: 'This is a test',
          },
        ],
      };

      [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(
          url,

@@ -125,9 +150,39 @@ describe('Open AI Utils', () => {
          true,
          DEFAULT_OPENAI_MODEL
        );
        expect(sanitizedBodyString).toEqual(
          `{\"model\":\"gpt-4\",\"messages\":[{\"role\":\"user\",\"content\":\"This is a test\"}],\"stream\":true}`
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          model: 'gpt-4',
          stream: true,
          stream_options: {
            include_usage: true,
          },
        });
      });
    });
    it('does not set stream_options when stream is false', () => {
      const body = {
        model: 'gpt-4',
        messages: [
          {
            role: 'user',
            content: 'This is a test',
          },
        ],
      };

      [OPENAI_CHAT_URL, OPENAI_LEGACY_COMPLETION_URL].forEach((url: string) => {
        const sanitizedBodyString = getRequestWithStreamOption(
          url,
          JSON.stringify(body),
          false,
          DEFAULT_OPENAI_MODEL
        );
        expect(JSON.parse(sanitizedBodyString)).toEqual({
          messages: [{ content: 'This is a test', role: 'user' }],
          model: 'gpt-4',
          stream: false,
        });
      });
    });

@@ -182,6 +237,7 @@ describe('Open AI Utils', () => {
      expect(sanitizedBodyString).toEqual(bodyString);
    });
  });

  describe('removeEndpointFromUrl', () => {
    test('removes "/chat/completions" from the end of the URL', () => {
      const originalUrl = 'https://api.openai.com/v1/chat/completions';

@@ -38,6 +38,11 @@ export const getRequestWithStreamOption = (
  if (jsonBody) {
    if (APIS_ALLOWING_STREAMING.has(url)) {
      jsonBody.stream = stream;
      if (stream) {
        jsonBody.stream_options = {
          include_usage: true,
        };
      }
    }
    jsonBody.model = jsonBody.model || defaultModel;
  }

@@ -292,6 +292,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...sampleOpenAiBody,
          stream: true,
          stream_options: { include_usage: true },
          model: DEFAULT_OPENAI_MODEL,
        }),
        headers: {

@@ -338,6 +339,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...body,
          stream: true,
          stream_options: { include_usage: true },
        }),
        headers: {
          Authorization: 'Bearer 123',

@@ -397,6 +399,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...sampleOpenAiBody,
          stream: true,
          stream_options: { include_usage: true },
          model: DEFAULT_OPENAI_MODEL,
        }),
        headers: {

@@ -422,6 +425,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...sampleOpenAiBody,
          stream: true,
          stream_options: { include_usage: true },
          model: DEFAULT_OPENAI_MODEL,
        }),
        headers: {

@@ -448,6 +452,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...sampleOpenAiBody,
          stream: true,
          stream_options: { include_usage: true },
          model: DEFAULT_OPENAI_MODEL,
        }),
        headers: {

@@ -1274,7 +1279,11 @@ describe('OpenAIConnector', () => {
        url: 'https://My-test-resource-123.openai.azure.com/openai/deployments/NEW-DEPLOYMENT-321/chat/completions?api-version=2023-05-15',
        method: 'post',
        responseSchema: StreamingResponseSchema,
        data: JSON.stringify({ ...sampleAzureAiBody, stream: true }),
        data: JSON.stringify({
          ...sampleAzureAiBody,
          stream: true,
          stream_options: { include_usage: true },
        }),
        headers: {
          'api-key': '123',
          'content-type': 'application/json',

@@ -1314,6 +1323,7 @@ describe('OpenAIConnector', () => {
        data: JSON.stringify({
          ...body,
          stream: true,
          stream_options: { include_usage: true },
        }),
        headers: {
          'api-key': '123',