[OpenAI Connector] Track token count for streaming responses (#168440)

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
This commit is contained in:
Dario Gieselaar 2023-10-18 14:04:21 +02:00 committed by GitHub
parent 85d39d0a34
commit 980e0cc704
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 458 additions and 49 deletions

View file

@ -905,6 +905,7 @@
"getopts": "^2.2.5",
"getos": "^3.1.0",
"globby": "^11.1.0",
"gpt-tokenizer": "^2.1.2",
"handlebars": "4.7.7",
"he": "^1.2.0",
"history": "^4.9.0",

View file

@ -105,7 +105,7 @@ module.exports = {
transformIgnorePatterns: [
// ignore all node_modules except monaco-editor and react-monaco-editor which requires babel transforms to handle dynamic import()
// since ESM modules are not natively supported in Jest yet (https://github.com/facebook/jest/issues/4842)
'[/\\\\]node_modules(?![\\/\\\\](byte-size|monaco-editor|monaco-yaml|vscode-languageserver-types|react-monaco-editor|d3-interpolate|d3-color|langchain|langsmith|@cfworker))[/\\\\].+\\.js$',
'[/\\\\]node_modules(?![\\/\\\\](byte-size|monaco-editor|monaco-yaml|vscode-languageserver-types|react-monaco-editor|d3-interpolate|d3-color|langchain|langsmith|@cfworker|gpt-tokenizer))[/\\\\].+\\.js$',
'packages/kbn-pm/dist/index.js',
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith))/dist/[/\\\\].+\\.js$',
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith))/dist/util/[/\\\\].+\\.js$',

View file

@ -22,7 +22,7 @@ module.exports = {
// An array of regexp pattern strings that are matched against, matched files will skip transformation:
transformIgnorePatterns: [
// since ESM modules are not natively supported in Jest yet (https://github.com/facebook/jest/issues/4842)
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith))[/\\\\].+\\.js$',
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith|gpt-tokenizer))[/\\\\].+\\.js$',
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith))/dist/[/\\\\].+\\.js$',
'[/\\\\]node_modules(?![\\/\\\\](langchain|langsmith))/dist/util/[/\\\\].+\\.js$',
],

View file

@ -20,6 +20,8 @@ import {
asSavedObjectExecutionSource,
} from './action_execution_source';
import { securityMock } from '@kbn/security-plugin/server/mocks';
import { finished } from 'stream/promises';
import { PassThrough } from 'stream';
const actionExecutor = new ActionExecutor({ isESOCanEncrypt: true });
const services = actionsMock.createServices();
@ -1837,6 +1839,102 @@ test('writes usage data to event log for OpenAI events', async () => {
});
});
// Verifies that when a .gen-ai connector returns a streaming response, the
// executor defers the final event-log entry until the stream finishes, then
// logs it with token-usage numbers computed from the SSE chunks.
test('writes usage data to event log for streaming OpenAI events', async () => {
  // Permissive schemas so the OpenAI-shaped params below pass validation.
  const executorMock = setupActionExecutorMock('.gen-ai', {
    params: { schema: schema.any() },
    config: { schema: schema.any() },
    secrets: { schema: schema.any() },
  });

  const stream = new PassThrough();

  // Connector resolves with a live stream instead of a plain result object.
  executorMock.mockResolvedValue({
    actionId: '1',
    status: 'ok',
    // @ts-ignore
    data: stream,
  });

  await actionExecutor.execute({
    ...executeParams,
    params: {
      subActionParams: {
        body: JSON.stringify({
          messages: [
            {
              role: 'system',
              content: 'System message',
            },
            {
              role: 'user',
              content: 'User message',
            },
          ],
        }),
      },
    },
  });

  // Only the "execute-start" event so far — the final event must wait for the stream.
  expect(eventLogger.logEvent).toHaveBeenCalledTimes(1);

  // Emit one OpenAI-style SSE chunk followed by the [DONE] terminator.
  stream.write(
    `data: ${JSON.stringify({
      object: 'chat.completion.chunk',
      choices: [{ delta: { content: 'Single' } }],
    })}\n`
  );
  stream.write(`data: [DONE]`);

  stream.end();

  await finished(stream);

  // flush microtasks so the token-count promise chain can log the final event
  await new Promise(process.nextTick);

  expect(eventLogger.logEvent).toHaveBeenCalledTimes(2);
  expect(eventLogger.logEvent).toHaveBeenNthCalledWith(2, {
    event: {
      action: 'execute',
      kind: 'action',
      outcome: 'success',
    },
    kibana: {
      action: {
        execution: {
          uuid: '2',
          // token counts — presumably what gpt-tokenizer produces for the
          // prompt/completion fixtures above; update together with them
          gen_ai: {
            usage: {
              completion_tokens: 5,
              prompt_tokens: 30,
              total_tokens: 35,
            },
          },
        },
        name: 'action-1',
        id: '1',
      },
      alert: {
        rule: {
          execution: {
            uuid: '123abc',
          },
        },
      },
      saved_objects: [
        {
          id: '1',
          namespace: 'some-namespace',
          rel: 'primary',
          type: 'action',
          type_id: '.gen-ai',
        },
      ],
      space_ids: ['some-namespace'],
    },
    message: 'action executed: .gen-ai:1: action-1',
    user: { name: 'coolguy', id: '123' },
  });
});
test('does not fetches actionInfo if passed as param', async () => {
const actionType: jest.Mocked<ActionType> = {
id: 'test',
@ -1898,13 +1996,16 @@ test('does not fetches actionInfo if passed as param', async () => {
);
});
function setupActionExecutorMock(actionTypeId = 'test') {
function setupActionExecutorMock(
actionTypeId = 'test',
validationOverride?: ActionType['validate']
) {
const actionType: jest.Mocked<ActionType> = {
id: 'test',
name: 'Test',
minimumLicenseRequired: 'basic',
supportedFeatureIds: ['alerting'],
validate: {
validate: validationOverride || {
config: { schema: schema.object({ bar: schema.boolean() }) },
secrets: { schema: schema.object({ baz: schema.boolean() }) },
params: { schema: schema.object({ foo: schema.boolean() }) },

View file

@ -13,6 +13,7 @@ import { EncryptedSavedObjectsClient } from '@kbn/encrypted-saved-objects-plugin
import { SpacesServiceStart } from '@kbn/spaces-plugin/server';
import { IEventLogger, SAVED_OBJECT_REL_PRIMARY } from '@kbn/event-log-plugin/server';
import { SecurityPluginStart } from '@kbn/security-plugin/server';
import { PassThrough, Readable } from 'stream';
import {
validateParams,
validateConfig,
@ -37,6 +38,7 @@ import { RelatedSavedObjects } from './related_saved_objects';
import { createActionEventLogRecordObject } from './create_action_event_log_record_object';
import { ActionExecutionError, ActionExecutionErrorReason } from './errors/action_execution_error';
import type { ActionsAuthorization } from '../authorization/actions_authorization';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';
// 1,000,000 nanoseconds in 1 millisecond
const Millis2Nanos = 1000 * 1000;
@ -276,8 +278,6 @@ export class ActionExecutor {
}
}
eventLogger.stopTiming(event);
// allow null-ish return to indicate success
const result = rawResult || {
actionId,
@ -286,6 +286,48 @@ export class ActionExecutor {
event.event = event.event || {};
const { error, ...resultWithoutError } = result;
function completeEventLogging() {
eventLogger.stopTiming(event);
const currentUser = security?.authc.getCurrentUser(request);
event.user = event.user || {};
event.user.name = currentUser?.username;
event.user.id = currentUser?.profile_uid;
if (result.status === 'ok') {
span?.setOutcome('success');
event.event!.outcome = 'success';
event.message = `action executed: ${actionLabel}`;
} else if (result.status === 'error') {
span?.setOutcome('failure');
event.event!.outcome = 'failure';
event.message = `action execution failure: ${actionLabel}`;
event.error = event.error || {};
event.error.message = actionErrorToMessage(result);
if (result.error) {
logger.error(result.error, {
tags: [actionTypeId, actionId, 'action-run-failed'],
error: { stack_trace: result.error.stack },
});
}
logger.warn(`action execution failure: ${actionLabel}: ${event.error.message}`);
} else {
span?.setOutcome('failure');
event.event!.outcome = 'failure';
event.message = `action execution returned unexpected result: ${actionLabel}: "${result.status}"`;
event.error = event.error || {};
event.error.message = 'action execution returned unexpected result';
logger.warn(
`action execution failure: ${actionLabel}: returned unexpected result "${result.status}"`
);
}
eventLogger.logEvent(event);
}
// start openai extension
// add event.kibana.action.execution.openai to event log when OpenAI Connector is executed
if (result.status === 'ok' && actionTypeId === '.gen-ai') {
@ -310,45 +352,34 @@ export class ActionExecutor {
},
},
};
if (result.data instanceof Readable) {
getTokenCountFromOpenAIStream({
responseStream: result.data.pipe(new PassThrough()),
body: (validatedParams as { subActionParams: { body: string } }).subActionParams.body,
})
.then(({ total, prompt, completion }) => {
event.kibana!.action!.execution!.gen_ai!.usage = {
total_tokens: total,
prompt_tokens: prompt,
completion_tokens: completion,
};
})
.catch((err) => {
logger.error('Failed to calculate tokens from streaming response');
logger.error(err);
})
.finally(() => {
completeEventLogging();
});
return resultWithoutError;
}
}
// end openai extension
const currentUser = security?.authc.getCurrentUser(request);
completeEventLogging();
event.user = event.user || {};
event.user.name = currentUser?.username;
event.user.id = currentUser?.profile_uid;
if (result.status === 'ok') {
span?.setOutcome('success');
event.event.outcome = 'success';
event.message = `action executed: ${actionLabel}`;
} else if (result.status === 'error') {
span?.setOutcome('failure');
event.event.outcome = 'failure';
event.message = `action execution failure: ${actionLabel}`;
event.error = event.error || {};
event.error.message = actionErrorToMessage(result);
if (result.error) {
logger.error(result.error, {
tags: [actionTypeId, actionId, 'action-run-failed'],
error: { stack_trace: result.error.stack },
});
}
logger.warn(`action execution failure: ${actionLabel}: ${event.error.message}`);
} else {
span?.setOutcome('failure');
event.event.outcome = 'failure';
event.message = `action execution returned unexpected result: ${actionLabel}: "${result.status}"`;
event.error = event.error || {};
event.error.message = 'action execution returned unexpected result';
logger.warn(
`action execution failure: ${actionLabel}: returned unexpected result "${result.status}"`
);
}
eventLogger.logEvent(event);
const { error, ...resultWithoutError } = result;
return resultWithoutError;
}
);

View file

@ -0,0 +1,138 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Transform } from 'stream';
import { getTokenCountFromOpenAIStream } from './get_token_count_from_openai_stream';
// Test double around a Node Transform stream: lets a test push SSE-style
// lines, end the stream normally, or make it fail mid-flight.
interface StreamMock {
  write: (data: string) => void;
  fail: () => void;
  complete: () => void;
  transform: Transform;
}

function createStreamMock(): StreamMock {
  const transform = new Transform({});

  // Each payload is pushed as its own newline-terminated line, mirroring
  // how the SSE response body arrives chunk by chunk.
  const write = (data: string) => {
    transform.push(`${data}\n`);
  };

  // Normal termination of the stream.
  const complete = () => {
    transform.end();
  };

  // Simulate a mid-flight failure: emit an error, then close the stream.
  const fail = () => {
    transform.emit('error', new Error('Stream failed'));
    transform.end();
  };

  return { write, fail, complete, transform };
}
describe('getTokenCountFromOpenAIStream', () => {
  let tokens: Awaited<ReturnType<typeof getTokenCountFromOpenAIStream>>;
  let stream: StreamMock;

  // Request body — the "prompt" side of the token count.
  const body = {
    messages: [
      {
        role: 'system',
        content: 'This is a system message',
      },
      {
        role: 'user',
        content: 'This is a user message',
      },
    ],
  };

  // One SSE chunk shaped like OpenAI's streamed chat-completion output.
  const chunk = {
    object: 'chat.completion.chunk',
    choices: [
      {
        delta: {
          content: 'Single',
        },
      },
    ],
  };

  // Expected counts for the fixtures above — presumably what gpt-tokenizer
  // encodes them to; update together with the fixtures.
  const PROMPT_TOKEN_COUNT = 36;
  const COMPLETION_TOKEN_COUNT = 5;

  beforeEach(() => {
    stream = createStreamMock();
    // every scenario starts with one completion chunk already written
    stream.write(`data: ${JSON.stringify(chunk)}`);
  });

  describe('when a stream completes', () => {
    beforeEach(async () => {
      // terminate the SSE stream the way OpenAI does
      stream.write('data: [DONE]');
      stream.complete();
    });

    describe('without function tokens', () => {
      beforeEach(async () => {
        tokens = await getTokenCountFromOpenAIStream({
          responseStream: stream.transform,
          body: JSON.stringify(body),
        });
      });

      it('counts the prompt tokens', () => {
        expect(tokens.prompt).toBe(PROMPT_TOKEN_COUNT);
        expect(tokens.completion).toBe(COMPLETION_TOKEN_COUNT);
        expect(tokens.total).toBe(PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT);
      });
    });

    describe('with function tokens', () => {
      beforeEach(async () => {
        tokens = await getTokenCountFromOpenAIStream({
          responseStream: stream.transform,
          body: JSON.stringify({
            ...body,
            // a function schema adds to the prompt-side token count
            functions: [
              {
                name: 'my_function',
                description: 'My function description',
                parameters: {
                  type: 'object',
                  properties: {
                    my_property: {
                      type: 'boolean',
                      description: 'My function property',
                    },
                  },
                },
              },
            ],
          }),
        });
      });

      it('counts the function tokens', () => {
        // exact count is an approximation, so only assert it grew
        expect(tokens.prompt).toBeGreaterThan(PROMPT_TOKEN_COUNT);
      });
    });
  });

  describe('when a stream fails', () => {
    it('resolves the promise with the correct prompt tokens', async () => {
      const tokenPromise = getTokenCountFromOpenAIStream({
        responseStream: stream.transform,
        body: JSON.stringify(body),
      });

      // failure must not reject — counts cover the data received before the error
      stream.fail();

      await expect(tokenPromise).resolves.toEqual({
        prompt: PROMPT_TOKEN_COUNT,
        total: PROMPT_TOKEN_COUNT + COMPLETION_TOKEN_COUNT,
        completion: COMPLETION_TOKEN_COUNT,
      });
    });
  });
});

View file

@ -0,0 +1,119 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { encode } from 'gpt-tokenizer';
import { isEmpty, omitBy } from 'lodash';
import { Readable } from 'stream';
import { finished } from 'stream/promises';
import { CreateChatCompletionRequest } from 'openai';
/**
 * Approximates token usage for a streamed OpenAI chat completion by encoding
 * the request body (prompt side) and the reconstructed SSE response
 * (completion side) with gpt-tokenizer.
 *
 * Resolves even if the response stream errors mid-flight: whatever data was
 * received before the failure is counted.
 *
 * @param responseStream raw SSE response stream from the connector
 * @param body           JSON-serialized CreateChatCompletionRequest that was sent
 * @returns approximate prompt / completion / total token counts
 */
export async function getTokenCountFromOpenAIStream({
  responseStream,
  body,
}: {
  responseStream: Readable;
  body: string;
}): Promise<{
  total: number;
  prompt: number;
  completion: number;
}> {
  const chatCompletionRequest = JSON.parse(body) as CreateChatCompletionRequest;

  // per https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
  const tokensFromMessages = encode(
    chatCompletionRequest.messages
      .map(
        (msg) =>
          `<|start|>${msg.role}\n${msg.content}\n${
            msg.name
              ? msg.name
              : msg.function_call
              ? msg.function_call.name + '\n' + msg.function_call.arguments
              : ''
          }<|end|>`
      )
      .join('\n')
  ).length;

  // this is an approximation. OpenAI cuts off a function schema
  // at a certain level of nesting, so their token count might
  // be lower than what we are calculating here.
  const tokensFromFunctions = chatCompletionRequest.functions
    ? encode(
        chatCompletionRequest.functions
          .map(
            (fn) =>
              `<|start|>${fn.name}\n${fn.description}\n${JSON.stringify(fn.parameters)}<|end|>`
          )
          .join('\n')
      ).length
    : 0;

  const promptTokens = tokensFromMessages + tokensFromFunctions;

  // Buffer the whole SSE body; chunks are not guaranteed to align with lines.
  let responseBody: string = '';

  responseStream.on('data', (chunk: string) => {
    responseBody += chunk.toString();
  });

  try {
    await finished(responseStream);
  } catch {
    // the stream can error mid-flight; count whatever data we did receive
  }

  const response = responseBody
    .split('\n')
    .filter((line) => {
      return line.startsWith('data: ') && !line.endsWith('[DONE]');
    })
    .map((line) => {
      // a line can be truncated if the stream errored mid-chunk; skip
      // anything that is not valid JSON instead of throwing away the count
      try {
        return JSON.parse(line.replace('data: ', ''));
      } catch {
        return undefined;
      }
    })
    .filter(
      (
        line
      ): line is {
        choices: Array<{
          delta: { content?: string; function_call?: { name?: string; arguments: string } };
        }>;
      } => {
        return (
          typeof line === 'object' &&
          line !== null &&
          'object' in line &&
          line.object === 'chat.completion.chunk'
        );
      }
    )
    .reduce(
      (prev, line) => {
        // defensively guard the indexed access — a chunk with an empty
        // choices array would otherwise crash here (TODO confirm whether
        // OpenAI can emit such chunks)
        const msg = line.choices[0]?.delta;
        if (msg) {
          prev.content += msg.content || '';
          prev.function_call.name += msg.function_call?.name || '';
          prev.function_call.arguments += msg.function_call?.arguments || '';
        }
        return prev;
      },
      { content: '', function_call: { name: '', arguments: '' } }
    );

  // Serialize the reconstructed completion and count its tokens; empty
  // parts are dropped so they do not inflate the count.
  const completionTokens = encode(
    JSON.stringify(
      omitBy(
        {
          content: response.content || undefined,
          function_call: response.function_call.name ? response.function_call : undefined,
        },
        isEmpty
      )
    )
  ).length;

  return {
    prompt: promptTokens,
    completion: completionTokens,
    total: promptTokens + completionTokens,
  };
}

View file

@ -20,6 +20,7 @@ import { SavedObjectsClientContract } from '@kbn/core-saved-objects-api-server';
import { ElasticsearchClient } from '@kbn/core-elasticsearch-server';
import { finished } from 'stream/promises';
import { IncomingMessage } from 'http';
import { PassThrough } from 'stream';
import { assertURL } from './helpers/validators';
import { ActionsConfigurationUtilities } from '../actions_config';
import { SubAction, SubActionRequestParams } from './types';
@ -158,11 +159,13 @@ export abstract class SubActionConnector<Config, Secrets> {
try {
const incomingMessage = error.response.data as IncomingMessage;
incomingMessage.on('data', (chunk) => {
const pt = incomingMessage.pipe(new PassThrough());
pt.on('data', (chunk) => {
responseBody += chunk.toString();
});
await finished(incomingMessage);
await finished(pt);
error.response.data = JSON.parse(responseBody);
} catch {

View file

@ -5,10 +5,10 @@
* 2.0.
*/
import { notImplemented } from '@hapi/boom';
import { IncomingMessage } from 'http';
import * as t from 'io-ts';
import { toBooleanRt } from '@kbn/io-ts-utils';
import type { CreateChatCompletionResponse } from 'openai';
import { Readable } from 'stream';
import { createObservabilityAIAssistantServerRoute } from '../create_observability_ai_assistant_server_route';
import { messageRt } from '../runtime_types';
@ -38,7 +38,7 @@ const chatRoute = createObservabilityAIAssistantServerRoute({
}),
t.partial({ query: t.type({ stream: toBooleanRt }) }),
]),
handler: async (resources): Promise<IncomingMessage | CreateChatCompletionResponse> => {
handler: async (resources): Promise<Readable | CreateChatCompletionResponse> => {
const { request, params, service } = resources;
const client = await service.getClient({ request });

View file

@ -10,7 +10,6 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { ElasticsearchClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { PublicMethodsOf } from '@kbn/utility-types';
import type { IncomingMessage } from 'http';
import { compact, isEmpty, merge, omit } from 'lodash';
import type {
ChatCompletionFunctions,
@ -18,10 +17,11 @@ import type {
CreateChatCompletionRequest,
CreateChatCompletionResponse,
} from 'openai';
import { PassThrough, Readable } from 'stream';
import { v4 } from 'uuid';
import {
type CompatibleJSONSchema,
MessageRole,
type CompatibleJSONSchema,
type Conversation,
type ConversationCreateRequest,
type ConversationUpdateRequest,
@ -116,7 +116,7 @@ export class ObservabilityAIAssistantClient {
functions?: Array<{ name: string; description: string; parameters: CompatibleJSONSchema }>;
functionCall?: string;
stream?: TStream;
}): Promise<TStream extends false ? CreateChatCompletionResponse : IncomingMessage> => {
}): Promise<TStream extends false ? CreateChatCompletionResponse : Readable> => {
const messagesForOpenAI: ChatCompletionRequestMessage[] = compact(
messages
.filter((message) => message.message.content || message.message.function_call?.name)
@ -195,7 +195,11 @@ export class ObservabilityAIAssistantClient {
throw internal(`${executeResult?.message} - ${executeResult?.serviceMessage}`);
}
return executeResult.data as any;
const response = stream
? ((executeResult.data as Readable).pipe(new PassThrough()) as Readable)
: (executeResult.data as CreateChatCompletionResponse);
return response as any;
};
find = async (options?: { query?: string }): Promise<{ conversations: Conversation[] }> => {

View file

@ -17856,6 +17856,13 @@ got@^9.6.0:
to-readable-stream "^1.0.0"
url-parse-lax "^3.0.0"
gpt-tokenizer@^2.1.2:
version "2.1.2"
resolved "https://registry.yarnpkg.com/gpt-tokenizer/-/gpt-tokenizer-2.1.2.tgz#14f7ce424cf2309fb5be66e112d1836080c2791a"
integrity sha512-HSuI5d6uey+c7x/VzQlPfCoGrfLyAc28vxWofKbjR9PJHm0AjQGSWkKw/OJnb+8S1g7nzgRsf0WH3dK+NNWYbg==
dependencies:
rfc4648 "^1.5.2"
graceful-fs@4.X, graceful-fs@^4.1.11, graceful-fs@^4.1.15, graceful-fs@^4.1.2, graceful-fs@^4.1.6, graceful-fs@^4.1.9, graceful-fs@^4.2.0, graceful-fs@^4.2.10, graceful-fs@^4.2.11, graceful-fs@^4.2.2, graceful-fs@^4.2.4, graceful-fs@^4.2.6, graceful-fs@^4.2.8, graceful-fs@^4.2.9:
version "4.2.11"
resolved "https://registry.yarnpkg.com/graceful-fs/-/graceful-fs-4.2.11.tgz#4183e4e8bf08bb6e05bbb2f7d2e0c8f712ca40e3"
@ -26769,6 +26776,11 @@ reusify@^1.0.4:
resolved "https://registry.yarnpkg.com/reusify/-/reusify-1.0.4.tgz#90da382b1e126efc02146e90845a88db12925d76"
integrity sha512-U9nH88a3fc/ekCF1l0/UP1IosiuIjyTh7hBvXVMHYgVcfGvt897Xguj2UOLDeI5BG2m7/uwyaLVT6fbtCwTyzw==
rfc4648@^1.5.2:
version "1.5.2"
resolved "https://registry.yarnpkg.com/rfc4648/-/rfc4648-1.5.2.tgz#cf5dac417dd83e7f4debf52e3797a723c1373383"
integrity sha512-tLOizhR6YGovrEBLatX1sdcuhoSCXddw3mqNVAcKxGJ+J0hFeJ+SjeWCv5UPA/WU3YzWPPuCVYgXBKZUPGpKtg==
rfdc@^1.2.0:
version "1.3.0"
resolved "https://registry.yarnpkg.com/rfdc/-/rfdc-1.3.0.tgz#d0b7c441ab2720d05dc4cf26e01c89631d9da08b"