[Security solution] Fix OpenAI token reporting (#169156)

This commit is contained in:
Steph Milovic 2023-10-18 12:06:03 -06:00 committed by GitHub
parent 7df3f964ce
commit 8284398023
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 147 additions and 27 deletions

View file

@ -54,7 +54,7 @@ describe('API tests', () => {
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":true}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
@ -72,12 +72,15 @@ describe('API tests', () => {
await fetchConnectorExecuteAction(testProps);
expect(mockHttp.fetch).toHaveBeenCalledWith('/api/actions/connector/foo/_execute', {
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
});
expect(mockHttp.fetch).toHaveBeenCalledWith(
'/internal/elastic_assistant/actions/connector/foo/_execute',
{
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
headers: { 'Content-Type': 'application/json' },
method: 'POST',
signal: undefined,
}
);
});
it('returns API_ERROR when the response status is not ok', async () => {

View file

@ -59,19 +59,16 @@ export const fetchConnectorExecuteAction = async ({
subActionParams: body,
subAction: 'invokeAI',
},
assistantLangChain,
};
try {
const path = assistantLangChain
? `/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`
: `/api/actions/connector/${apiConfig?.connectorId}/_execute`;
const response = await http.fetch<{
connector_id: string;
status: string;
data: string;
service_message?: string;
}>(path, {
}>(`/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',

View file

@ -5,4 +5,7 @@
* 2.0.
*/
export const mockActionResponse = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
export const mockActionResponse = {
message: 'Yes, your name is Andrew. How can I assist you further, Andrew?',
usage: { prompt_tokens: 4, completion_tokens: 10, total_tokens: 14 },
};

View file

@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { get } from 'lodash/fp';
import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
import { KibanaRequest } from '@kbn/core-http-server';
import { RequestBody } from './langchain/types';
/**
 * Arguments for {@link executeAction}.
 */
interface Props {
  // Actions plugin start contract, used to get a request-scoped actions client.
  actions: ActionsPluginStart;
  // ID of the connector (e.g. OpenAI / Bedrock) to execute.
  connectorId: string;
  // Incoming Kibana request; `request.body.params` is forwarded to the connector.
  request: KibanaRequest<unknown, unknown, RequestBody>;
}
/**
 * Response returned to the client when the assistant calls the connector
 * directly (the non-LangChain path).
 */
interface StaticResponse {
  connector_id: string;
  data: string;
  status: string;
}
/**
 * Executes the given connector directly via the actions framework (the
 * non-LangChain path) and returns the model's message text.
 *
 * @param actions - actions plugin start contract
 * @param request - incoming request; `request.body.params` is passed through to the connector
 * @param connectorId - ID of the connector to execute
 * @returns the connector response message wrapped in a `StaticResponse`
 * @throws Error when the action reports an error status or returns an
 *   unexpected payload shape
 */
export const executeAction = async ({
  actions,
  request,
  connectorId,
}: Props): Promise<StaticResponse> => {
  // Scope the actions client to the current request so the connector
  // executes with the caller's privileges.
  const actionsClient = await actions.getActionsClientWithRequest(request);
  const actionResult = await actionsClient.execute({
    actionId: connectorId,
    params: request.body.params,
  });
  // Surface connector-level failures with their details, consistent with
  // ActionsClientLlm, instead of falling through to the generic error below.
  if (actionResult.status === 'error') {
    throw new Error(
      `Action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
    );
  }
  // Connectors return `{ message, ... }` in `data` (see InvokeAIActionResponseSchema).
  const content = get('data.message', actionResult);
  if (typeof content === 'string') {
    return {
      connector_id: connectorId,
      data: content, // the response from the actions framework
      status: 'ok',
    };
  }
  throw new Error('Unexpected action result');
};

View file

@ -51,6 +51,7 @@ const mockRequest: KibanaRequest<unknown, unknown, RequestBody> = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
} as KibanaRequest<unknown, unknown, RequestBody>;
@ -72,7 +73,7 @@ describe('ActionsClientLlm', () => {
await actionsClientLlm._call(prompt); // ignore the result
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse);
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse.message);
});
});
@ -141,7 +142,7 @@ describe('ActionsClientLlm', () => {
});
it('rejects with the expected error the message has invalid content', async () => {
const invalidContent = 1234;
const invalidContent = { message: 1234 };
mockExecute.mockImplementation(() => ({
data: invalidContent,

View file

@ -92,9 +92,8 @@ export class ActionsClientLlm extends LLM {
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}
// TODO: handle errors from the connector
const content = get('data', actionResult);
const content = get('data.message', actionResult);
if (typeof content !== 'string') {
throw new Error(

View file

@ -105,6 +105,7 @@ export const postEvaluateRoute = (
messages: [],
},
},
assistantLangChain: true,
},
};

View file

@ -19,6 +19,13 @@ import { coreMock } from '@kbn/core/server/mocks';
jest.mock('../lib/build_response', () => ({
buildResponse: jest.fn().mockImplementation((x) => x),
}));
jest.mock('../lib/executor', () => ({
executeAction: jest.fn().mockImplementation((x) => ({
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
})),
}));
jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({
callAgentExecutor: jest.fn().mockImplementation(
@ -82,6 +89,7 @@ const mockRequest = {
},
subAction: 'invokeAI',
},
assistantLangChain: true,
},
};
@ -97,7 +105,38 @@ describe('postActionsConnectorExecuteRoute', () => {
jest.clearAllMocks();
});
it('returns the expected response', async () => {
it('returns the expected response when assistantLangChain=false', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(
mockContext,
{
...mockRequest,
body: {
...mockRequest.body,
assistantLangChain: false,
},
},
mockResponse
);
expect(result).toEqual({
body: {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
},
});
}),
};
await postActionsConnectorExecuteRoute(
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
mockGetElser
);
});
it('returns the expected response when assistantLangChain=true', async () => {
const mockRouter = {
post: jest.fn().mockImplementation(async (_, handler) => {
const result = await handler(mockContext, mockRequest, mockResponse);

View file

@ -7,6 +7,7 @@
import { IRouter, Logger } from '@kbn/core/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import { executeAction } from '../lib/executor';
import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants';
import { getLangChainMessages } from '../lib/langchain/helpers';
import { buildResponse } from '../lib/build_response';
@ -41,6 +42,14 @@ export const postActionsConnectorExecuteRoute = (
// get the actions plugin start contract from the request context:
const actions = (await context.elasticAssistant).actions;
// if not langchain, call execute action directly and return the response:
if (!request.body.assistantLangChain) {
const result = await executeAction({ actions, request, connectorId });
return response.ok({
body: result,
});
}
// get a scoped esClient for assistant memory
const esClient = (await context.core).elasticsearch.client.asCurrentUser;

View file

@ -34,6 +34,7 @@ export const PostActionsConnectorExecuteBody = t.type({
]),
subAction: t.string,
}),
assistantLangChain: t.boolean,
});
export type PostActionsConnectorExecuteBodyInputs = t.TypeOf<

View file

@ -34,7 +34,9 @@ export const InvokeAIActionParamsSchema = schema.object({
model: schema.maybe(schema.string()),
});
export const InvokeAIActionResponseSchema = schema.string();
export const InvokeAIActionResponseSchema = schema.object({
message: schema.string(),
});
export const RunActionResponseSchema = schema.object(
{

View file

@ -44,7 +44,17 @@ export const InvokeAIActionParamsSchema = schema.object({
temperature: schema.maybe(schema.number()),
});
export const InvokeAIActionResponseSchema = schema.string();
export const InvokeAIActionResponseSchema = schema.object({
message: schema.string(),
usage: schema.object(
{
prompt_tokens: schema.number(),
completion_tokens: schema.number(),
total_tokens: schema.number(),
},
{ unknowns: 'ignore' }
),
});
// Execute action schema
export const StreamActionParamsSchema = schema.object({

View file

@ -109,7 +109,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});
it('Properly formats messages from user, assistant, and system', async () => {
@ -148,7 +148,7 @@ describe('BedrockConnector', () => {
stop_sequences: ['\n\nHuman:'],
}),
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
});
it('errors during API calls are properly handled', async () => {

View file

@ -150,6 +150,6 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
};
const res = await this.runApi({ body: JSON.stringify(req), model });
return res.completion.trim();
return { message: res.completion.trim() };
}
}

View file

@ -37,6 +37,11 @@ describe('OpenAIConnector', () => {
index: 0,
},
],
usage: {
prompt_tokens: 4,
completion_tokens: 5,
total_tokens: 9,
},
},
};
beforeEach(() => {
@ -273,7 +278,8 @@ describe('OpenAIConnector', () => {
'content-type': 'application/json',
},
});
expect(response).toEqual(mockResponseString);
expect(response.message).toEqual(mockResponseString);
expect(response.usage.total_tokens).toEqual(9);
});
it('errors during API calls are properly handled', async () => {

View file

@ -192,9 +192,15 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
const result = res.choices[0].message.content.trim();
return result;
return { message: result, usage: res.usage };
}
return 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.';
return {
message:
'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.',
...(res.usage
? { usage: res.usage }
: { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }),
};
}
}

View file

@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
expect(body).to.eql({
status: 'ok',
connector_id: bedrockActionId,
data: bedrockSuccessResponse.completion,
data: { message: bedrockSuccessResponse.completion },
});
});
});