mirror of
https://github.com/elastic/kibana.git
synced 2025-04-24 09:48:58 -04:00
[Security solution] Fix OpenAI token reporting (#169156)
This commit is contained in:
parent
7df3f964ce
commit
8284398023
17 changed files with 147 additions and 27 deletions
|
@ -54,7 +54,7 @@ describe('API tests', () => {
|
|||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/actions/connector/foo/_execute',
|
||||
{
|
||||
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
|
||||
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":true}',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
method: 'POST',
|
||||
signal: undefined,
|
||||
|
@ -72,12 +72,15 @@ describe('API tests', () => {
|
|||
|
||||
await fetchConnectorExecuteAction(testProps);
|
||||
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith('/api/actions/connector/foo/_execute', {
|
||||
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"}}',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
method: 'POST',
|
||||
signal: undefined,
|
||||
});
|
||||
expect(mockHttp.fetch).toHaveBeenCalledWith(
|
||||
'/internal/elastic_assistant/actions/connector/foo/_execute',
|
||||
{
|
||||
body: '{"params":{"subActionParams":{"model":"gpt-4","messages":[{"role":"user","content":"This is a test"}],"n":1,"stop":null,"temperature":0.2},"subAction":"invokeAI"},"assistantLangChain":false}',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
method: 'POST',
|
||||
signal: undefined,
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
it('returns API_ERROR when the response status is not ok', async () => {
|
||||
|
|
|
@ -59,19 +59,16 @@ export const fetchConnectorExecuteAction = async ({
|
|||
subActionParams: body,
|
||||
subAction: 'invokeAI',
|
||||
},
|
||||
assistantLangChain,
|
||||
};
|
||||
|
||||
try {
|
||||
const path = assistantLangChain
|
||||
? `/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`
|
||||
: `/api/actions/connector/${apiConfig?.connectorId}/_execute`;
|
||||
|
||||
const response = await http.fetch<{
|
||||
connector_id: string;
|
||||
status: string;
|
||||
data: string;
|
||||
service_message?: string;
|
||||
}>(path, {
|
||||
}>(`/internal/elastic_assistant/actions/connector/${apiConfig?.connectorId}/_execute`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
|
|
|
@ -5,4 +5,7 @@
|
|||
* 2.0.
|
||||
*/
|
||||
|
||||
export const mockActionResponse = 'Yes, your name is Andrew. How can I assist you further, Andrew?';
|
||||
export const mockActionResponse = {
|
||||
message: 'Yes, your name is Andrew. How can I assist you further, Andrew?',
|
||||
usage: { prompt_tokens: 4, completion_tokens: 10, total_tokens: 14 },
|
||||
};
|
||||
|
|
43
x-pack/plugins/elastic_assistant/server/lib/executor.ts
Normal file
43
x-pack/plugins/elastic_assistant/server/lib/executor.ts
Normal file
|
@ -0,0 +1,43 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
import { get } from 'lodash/fp';
|
||||
import { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server';
|
||||
import { KibanaRequest } from '@kbn/core-http-server';
|
||||
import { RequestBody } from './langchain/types';
|
||||
|
||||
interface Props {
|
||||
actions: ActionsPluginStart;
|
||||
connectorId: string;
|
||||
request: KibanaRequest<unknown, unknown, RequestBody>;
|
||||
}
|
||||
interface StaticResponse {
|
||||
connector_id: string;
|
||||
data: string;
|
||||
status: string;
|
||||
}
|
||||
|
||||
export const executeAction = async ({
|
||||
actions,
|
||||
request,
|
||||
connectorId,
|
||||
}: Props): Promise<StaticResponse> => {
|
||||
const actionsClient = await actions.getActionsClientWithRequest(request);
|
||||
const actionResult = await actionsClient.execute({
|
||||
actionId: connectorId,
|
||||
params: request.body.params,
|
||||
});
|
||||
const content = get('data.message', actionResult);
|
||||
if (typeof content === 'string') {
|
||||
return {
|
||||
connector_id: connectorId,
|
||||
data: content, // the response from the actions framework
|
||||
status: 'ok',
|
||||
};
|
||||
}
|
||||
throw new Error('Unexpected action result');
|
||||
};
|
|
@ -51,6 +51,7 @@ const mockRequest: KibanaRequest<unknown, unknown, RequestBody> = {
|
|||
},
|
||||
subAction: 'invokeAI',
|
||||
},
|
||||
assistantLangChain: true,
|
||||
},
|
||||
} as KibanaRequest<unknown, unknown, RequestBody>;
|
||||
|
||||
|
@ -72,7 +73,7 @@ describe('ActionsClientLlm', () => {
|
|||
|
||||
await actionsClientLlm._call(prompt); // ignore the result
|
||||
|
||||
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse);
|
||||
expect(actionsClientLlm.getActionResultData()).toEqual(mockActionResponse.message);
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -141,7 +142,7 @@ describe('ActionsClientLlm', () => {
|
|||
});
|
||||
|
||||
it('rejects with the expected error the message has invalid content', async () => {
|
||||
const invalidContent = 1234;
|
||||
const invalidContent = { message: 1234 };
|
||||
|
||||
mockExecute.mockImplementation(() => ({
|
||||
data: invalidContent,
|
||||
|
|
|
@ -92,9 +92,8 @@ export class ActionsClientLlm extends LLM {
|
|||
`${LLM_TYPE}: action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: handle errors from the connector
|
||||
const content = get('data', actionResult);
|
||||
const content = get('data.message', actionResult);
|
||||
|
||||
if (typeof content !== 'string') {
|
||||
throw new Error(
|
||||
|
|
|
@ -105,6 +105,7 @@ export const postEvaluateRoute = (
|
|||
messages: [],
|
||||
},
|
||||
},
|
||||
assistantLangChain: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
|
|
@ -19,6 +19,13 @@ import { coreMock } from '@kbn/core/server/mocks';
|
|||
jest.mock('../lib/build_response', () => ({
|
||||
buildResponse: jest.fn().mockImplementation((x) => x),
|
||||
}));
|
||||
jest.mock('../lib/executor', () => ({
|
||||
executeAction: jest.fn().mockImplementation((x) => ({
|
||||
connector_id: 'mock-connector-id',
|
||||
data: mockActionResponse,
|
||||
status: 'ok',
|
||||
})),
|
||||
}));
|
||||
|
||||
jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({
|
||||
callAgentExecutor: jest.fn().mockImplementation(
|
||||
|
@ -82,6 +89,7 @@ const mockRequest = {
|
|||
},
|
||||
subAction: 'invokeAI',
|
||||
},
|
||||
assistantLangChain: true,
|
||||
},
|
||||
};
|
||||
|
||||
|
@ -97,7 +105,38 @@ describe('postActionsConnectorExecuteRoute', () => {
|
|||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns the expected response', async () => {
|
||||
it('returns the expected response when assistantLangChain=false', async () => {
|
||||
const mockRouter = {
|
||||
post: jest.fn().mockImplementation(async (_, handler) => {
|
||||
const result = await handler(
|
||||
mockContext,
|
||||
{
|
||||
...mockRequest,
|
||||
body: {
|
||||
...mockRequest.body,
|
||||
assistantLangChain: false,
|
||||
},
|
||||
},
|
||||
mockResponse
|
||||
);
|
||||
|
||||
expect(result).toEqual({
|
||||
body: {
|
||||
connector_id: 'mock-connector-id',
|
||||
data: mockActionResponse,
|
||||
status: 'ok',
|
||||
},
|
||||
});
|
||||
}),
|
||||
};
|
||||
|
||||
await postActionsConnectorExecuteRoute(
|
||||
mockRouter as unknown as IRouter<ElasticAssistantRequestHandlerContext>,
|
||||
mockGetElser
|
||||
);
|
||||
});
|
||||
|
||||
it('returns the expected response when assistantLangChain=true', async () => {
|
||||
const mockRouter = {
|
||||
post: jest.fn().mockImplementation(async (_, handler) => {
|
||||
const result = await handler(mockContext, mockRequest, mockResponse);
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
|
||||
import { IRouter, Logger } from '@kbn/core/server';
|
||||
import { transformError } from '@kbn/securitysolution-es-utils';
|
||||
import { executeAction } from '../lib/executor';
|
||||
import { POST_ACTIONS_CONNECTOR_EXECUTE } from '../../common/constants';
|
||||
import { getLangChainMessages } from '../lib/langchain/helpers';
|
||||
import { buildResponse } from '../lib/build_response';
|
||||
|
@ -41,6 +42,14 @@ export const postActionsConnectorExecuteRoute = (
|
|||
// get the actions plugin start contract from the request context:
|
||||
const actions = (await context.elasticAssistant).actions;
|
||||
|
||||
// if not langchain, call execute action directly and return the response:
|
||||
if (!request.body.assistantLangChain) {
|
||||
const result = await executeAction({ actions, request, connectorId });
|
||||
return response.ok({
|
||||
body: result,
|
||||
});
|
||||
}
|
||||
|
||||
// get a scoped esClient for assistant memory
|
||||
const esClient = (await context.core).elasticsearch.client.asCurrentUser;
|
||||
|
||||
|
|
|
@ -34,6 +34,7 @@ export const PostActionsConnectorExecuteBody = t.type({
|
|||
]),
|
||||
subAction: t.string,
|
||||
}),
|
||||
assistantLangChain: t.boolean,
|
||||
});
|
||||
|
||||
export type PostActionsConnectorExecuteBodyInputs = t.TypeOf<
|
||||
|
|
|
@ -34,7 +34,9 @@ export const InvokeAIActionParamsSchema = schema.object({
|
|||
model: schema.maybe(schema.string()),
|
||||
});
|
||||
|
||||
export const InvokeAIActionResponseSchema = schema.string();
|
||||
export const InvokeAIActionResponseSchema = schema.object({
|
||||
message: schema.string(),
|
||||
});
|
||||
|
||||
export const RunActionResponseSchema = schema.object(
|
||||
{
|
||||
|
|
|
@ -44,7 +44,17 @@ export const InvokeAIActionParamsSchema = schema.object({
|
|||
temperature: schema.maybe(schema.number()),
|
||||
});
|
||||
|
||||
export const InvokeAIActionResponseSchema = schema.string();
|
||||
export const InvokeAIActionResponseSchema = schema.object({
|
||||
message: schema.string(),
|
||||
usage: schema.object(
|
||||
{
|
||||
prompt_tokens: schema.number(),
|
||||
completion_tokens: schema.number(),
|
||||
total_tokens: schema.number(),
|
||||
},
|
||||
{ unknowns: 'ignore' }
|
||||
),
|
||||
});
|
||||
|
||||
// Execute action schema
|
||||
export const StreamActionParamsSchema = schema.object({
|
||||
|
|
|
@ -109,7 +109,7 @@ describe('BedrockConnector', () => {
|
|||
stop_sequences: ['\n\nHuman:'],
|
||||
}),
|
||||
});
|
||||
expect(response).toEqual(mockResponseString);
|
||||
expect(response.message).toEqual(mockResponseString);
|
||||
});
|
||||
|
||||
it('Properly formats messages from user, assistant, and system', async () => {
|
||||
|
@ -148,7 +148,7 @@ describe('BedrockConnector', () => {
|
|||
stop_sequences: ['\n\nHuman:'],
|
||||
}),
|
||||
});
|
||||
expect(response).toEqual(mockResponseString);
|
||||
expect(response.message).toEqual(mockResponseString);
|
||||
});
|
||||
|
||||
it('errors during API calls are properly handled', async () => {
|
||||
|
|
|
@ -150,6 +150,6 @@ export class BedrockConnector extends SubActionConnector<Config, Secrets> {
|
|||
};
|
||||
|
||||
const res = await this.runApi({ body: JSON.stringify(req), model });
|
||||
return res.completion.trim();
|
||||
return { message: res.completion.trim() };
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,6 +37,11 @@ describe('OpenAIConnector', () => {
|
|||
index: 0,
|
||||
},
|
||||
],
|
||||
usage: {
|
||||
prompt_tokens: 4,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 9,
|
||||
},
|
||||
},
|
||||
};
|
||||
beforeEach(() => {
|
||||
|
@ -273,7 +278,8 @@ describe('OpenAIConnector', () => {
|
|||
'content-type': 'application/json',
|
||||
},
|
||||
});
|
||||
expect(response).toEqual(mockResponseString);
|
||||
expect(response.message).toEqual(mockResponseString);
|
||||
expect(response.usage.total_tokens).toEqual(9);
|
||||
});
|
||||
|
||||
it('errors during API calls are properly handled', async () => {
|
||||
|
|
|
@ -192,9 +192,15 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
|
|||
|
||||
if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
|
||||
const result = res.choices[0].message.content.trim();
|
||||
return result;
|
||||
return { message: result, usage: res.usage };
|
||||
}
|
||||
|
||||
return 'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.';
|
||||
return {
|
||||
message:
|
||||
'An error occurred sending your message. \n\nAPI Error: The response from OpenAI was in an unrecognized format.',
|
||||
...(res.usage
|
||||
? { usage: res.usage }
|
||||
: { usage: { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 } }),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -404,7 +404,7 @@ export default function bedrockTest({ getService }: FtrProviderContext) {
|
|||
expect(body).to.eql({
|
||||
status: 'ok',
|
||||
connector_id: bedrockActionId,
|
||||
data: bedrockSuccessResponse.completion,
|
||||
data: { message: bedrockSuccessResponse.completion },
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue