Mirror of https://github.com/elastic/kibana.git
# [GenAI Connectors] Fix AbortSignal implementation (#180855)
## Bugs Fixed
1. The OpenAI `invokeAI` method did not properly handle `signal`
2. Bedrock did not have a `signal` implementation at all 😳
## Summary
In my [LangChain streaming
PR](https://github.com/elastic/kibana/pull/174126), I poorly implemented
a fix to stop the stream on the server when "Stop generating..." was hit
on the client. I did this by piping through an `AbortSignal` to
`invokeStream`/`invokeAsyncIterator` subactions. However, in the
`invokeAI` subaction I did not properly remove `signal` before
`JSON.strinigfy`ing the body, so the below error was happening in the
security non-streaming implementation. Additionally, for Bedrock I
somehow only implemented `signal` in part of the type and nothing else,
so token tracking would be off when Stop generating button is hit 🤦
![Screenshot 2024-04-15 at 2 00 38 PM](e57241d9-9fd2-4dd3-bb3a-72a7c61a3d4b)
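
The OpenAI fix boils down to pulling `signal` out of the params before serializing the body, then handing it to the HTTP layer separately. A minimal standalone sketch of the pattern (the `InvokeParams`/`toRequest` names are illustrative, not the connector's real types):

```ts
// Sketch (assumed param shape): strip the non-serializable AbortSignal off the
// params before JSON.stringify, and pass it to the HTTP layer separately.
interface InvokeParams {
  messages: Array<{ role: string; content: string }>;
  signal?: AbortSignal;
}

function toRequest(params: InvokeParams): { body: string; signal?: AbortSignal } {
  const { signal, ...rest } = params; // pull `signal` off the payload
  return { body: JSON.stringify(rest), signal }; // serialize only the data fields
}

// Usage: `req.body` contains no "signal" key; `req.signal` goes to axios/fetch.
const controller = new AbortController();
const req = toRequest({
  messages: [{ role: 'user', content: 'hello' }],
  signal: controller.signal,
});
```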
## To test
1. Turn off streaming in the Security AI Assistant and select an OpenAI connector (LangChain off)
2. Send a message
3. Ensure expected results (prior to this fix, the above error would occur)
The Bedrock fix is harder to confirm because the issue shows up somewhere
subtle: the token counter. Before the signal was implemented in the Bedrock
connector, if you asked Bedrock to repeat a word 100 times with streaming
enabled and hit "Stop generating..." after 10 words, the `completion_tokens`
count would be equivalent to ~100 tokens, since the full response would still
have "streamed" on the server. After this bug fix, hitting
"Stop generating..." after 10 words yields a `completion_tokens` count
equivalent to ~15 tokens, as it takes a moment for the `abort()` to reach the
server. To be clear, this bug would not have shown up in persistent storage,
because we call abort in `handleStreamStorage` ASAP instead of relying on
axios to complete its abort.
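
For context, the client side of "Stop generating..." is just the standard `AbortController` pattern. A hypothetical sketch (the endpoint path and `stopButton` are made up for illustration, not the assistant's actual code):

```ts
// Hypothetical client wiring: one AbortController per in-flight generation.
const controller = new AbortController();
const stopButton = document.querySelector<HTMLButtonElement>('#stop-generating')!;

async function sendMessage(prompt: string): Promise<string> {
  const response = await fetch('/api/assistant/chat', {
    // assumed endpoint, for illustration only
    method: 'POST',
    body: JSON.stringify({ message: prompt }),
    signal: controller.signal, // the same signal the server pipes through to axios
  });
  return response.text();
}

// Clicking "Stop generating..." aborts the request; once the abort reaches the
// server, the piped-through signal cancels the upstream LLM call, which is why
// the completion_tokens count lands near where the user stopped.
stopButton.addEventListener('click', () => controller.abort());
```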
This commit is contained in:
parent: b53624d472
commit: 7bd8815301

8 changed files with 118 additions and 12 deletions
@@ -45,6 +45,19 @@ Object {
       ],
       "type": "string",
     },
+    "signal": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "any",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
@@ -134,6 +147,19 @@ Object {
       ],
       "type": "string",
     },
+    "signal": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "any",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {

@@ -22,6 +22,8 @@ export const SecretsSchema = schema.object({
 export const RunActionParamsSchema = schema.object({
   body: schema.string(),
   model: schema.maybe(schema.string()),
+  // abort signal from client
+  signal: schema.maybe(schema.any()),
 });
 
 export const InvokeAIActionParamsSchema = schema.object({
@@ -43,11 +45,6 @@ export const InvokeAIActionResponseSchema = schema.object({
   message: schema.string(),
 });
 
-export const StreamActionParamsSchema = schema.object({
-  body: schema.string(),
-  model: schema.maybe(schema.string()),
-});
-
 export const RunApiLatestResponseSchema = schema.object(
   {
     stop_reason: schema.maybe(schema.string()),

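Aside: the schema above declares `signal` as `schema.maybe(schema.any())` because `@kbn/config-schema` has no validator for an `AbortSignal`. A small sketch of what that accepts (assuming the standard `@kbn/config-schema` API):

```ts
import { schema } from '@kbn/config-schema';

const RunActionParamsSchema = schema.object({
  body: schema.string(),
  model: schema.maybe(schema.string()),
  // abort signal from client; `any` because config-schema cannot model AbortSignal
  signal: schema.maybe(schema.any()),
});

// Both pass validation: `signal` is optional and unconstrained when present.
RunActionParamsSchema.validate({ body: '{}' });
RunActionParamsSchema.validate({ body: '{}', signal: new AbortController().signal });
```
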
@@ -15,7 +15,6 @@ import {
   RunActionResponseSchema,
   InvokeAIActionParamsSchema,
   InvokeAIActionResponseSchema,
-  StreamActionParamsSchema,
   StreamingResponseSchema,
   RunApiLatestResponseSchema,
 } from './schema';
@@ -27,7 +26,6 @@ export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
 export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
 export type RunApiLatestResponse = TypeOf<typeof RunApiLatestResponseSchema>;
 export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
-export type StreamActionParams = TypeOf<typeof StreamActionParamsSchema>;
 export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
 export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
 export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;

@@ -28,6 +28,8 @@ export const SecretsSchema = schema.object({ apiKey: schema.string() });
 // Run action schema
 export const RunActionParamsSchema = schema.object({
   body: schema.string(),
+  // abort signal from client
+  signal: schema.maybe(schema.any()),
 });
 
 const AIMessage = schema.object({

@@ -200,6 +200,21 @@ describe('BedrockConnector', () => {
       });
     });
 
+    it('signal is properly passed to streamApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeStream({ ...aiAssistantBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        signed: true,
+        url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke-with-response-stream`,
+        method: 'post',
+        responseSchema: StreamingResponseSchema,
+        responseType: 'stream',
+        data: JSON.stringify({ ...JSON.parse(DEFAULT_BODY), temperature: 0 }),
+        signal,
+      });
+    });
+
     it('ensureMessageFormat - formats messages from user, assistant, and system', async () => {
       await connector.invokeStream({
         messages: [
@@ -502,7 +517,25 @@ describe('BedrockConnector', () => {
       });
       expect(response.message).toEqual(mockResponseString);
     });
+
+    it('signal is properly passed to runApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeAI({ ...aiAssistantBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        signed: true,
+        timeout: 120000,
+        url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
+        method: 'post',
+        responseSchema: RunApiLatestResponseSchema,
+        data: JSON.stringify({
+          ...JSON.parse(DEFAULT_BODY),
+          messages: [{ content: 'Hello world', role: 'user' }],
+          max_tokens: DEFAULT_TOKEN_LIMIT,
+          temperature: 0,
+        }),
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;

@@ -26,7 +26,6 @@ import {
   RunActionResponse,
   InvokeAIActionParams,
   InvokeAIActionResponse,
-  StreamActionParams,
   RunApiLatestResponse,
 } from '../../../common/bedrock/types';
 import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
@@ -204,7 +203,11 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
    * @param body The stringified request body to be sent in the POST request.
    * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
    */
-  public async runApi({ body, model: reqModel }: RunActionParams): Promise<RunActionResponse> {
+  public async runApi({
+    body,
+    model: reqModel,
+    signal,
+  }: RunActionParams): Promise<RunActionResponse> {
     // set model on per request basis
     const currentModel = reqModel ?? this.model;
     const path = `/model/${currentModel}/invoke`;
@@ -214,6 +217,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
       url: `${this.url}${path}`,
       method: 'post' as Method,
       data: body,
+      signal,
       // give up to 2 minutes for response
       timeout: 120000,
     };
@@ -235,7 +239,8 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
   private async streamApi({
     body,
     model: reqModel,
-  }: StreamActionParams): Promise<StreamingResponse> {
+    signal,
+  }: RunActionParams): Promise<StreamingResponse> {
     // set model on per request basis
     const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`;
     const signed = this.signRequest(body, path, true);
@@ -247,6 +252,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
       responseSchema: StreamingResponseSchema,
       data: body,
       responseType: 'stream',
+      signal,
     });
 
     return response.data.pipe(new PassThrough());
@@ -266,10 +272,12 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
     stopSequences,
     system,
     temperature,
+    signal,
   }: InvokeAIActionParams): Promise<IncomingMessage> {
     const res = (await this.streamApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
+      signal,
     })) as unknown as IncomingMessage;
     return res;
   }
@@ -288,10 +296,12 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
     stopSequences,
     system,
     temperature,
+    signal,
   }: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
     const res = await this.runApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
+      signal,
     });
     return { message: res.completion.trim() };
   }

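Worth noting: `connector.request` is backed by axios, which accepts an `AbortSignal` via the `signal` request-config field (since axios v0.22), so forwarding the client's signal is all the cancellation plumbing needed. A standalone sketch (the URL is a placeholder):

```ts
import axios from 'axios';

const controller = new AbortController();

// Passing `signal` in the request config lets the caller cancel the HTTP call.
const pending = axios.post(
  'https://bedrock.example.com/model/some-model/invoke', // placeholder URL
  { prompt: 'repeat "hello" 100 times' },
  {
    signal: controller.signal,
    timeout: 120000, // same 2-minute ceiling the connector uses
  }
);

controller.abort(); // rejects `pending` with a CanceledError
pending.catch((err) => console.log(axios.isCancel(err))); // logs: true
```
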
@@ -336,6 +336,25 @@ describe('OpenAIConnector', () => {
       });
     });
 
+    it('signal is properly passed to streamApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeStream({ ...sampleOpenAiBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        url: 'https://api.openai.com/v1/chat/completions',
+        method: 'post',
+        responseSchema: StreamingResponseSchema,
+        responseType: 'stream',
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
@@ -371,6 +390,25 @@ describe('OpenAIConnector', () => {
       expect(response.usage.total_tokens).toEqual(9);
     });
 
+    it('signal is properly passed to runApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeAI({ ...sampleOpenAiBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        timeout: 120000,
+        url: 'https://api.openai.com/v1/chat/completions',
+        method: 'post',
+        responseSchema: RunActionResponseSchema,
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;

@@ -155,7 +155,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * responsible for making a POST request to the external API endpoint and returning the response data
    * @param body The stringified request body to be sent in the POST request.
    */
-  public async runApi({ body }: RunActionParams): Promise<RunActionResponse> {
+  public async runApi({ body, signal }: RunActionParams): Promise<RunActionResponse> {
     const sanitizedBody = sanitizeRequest(
       this.provider,
       this.url,
@@ -168,6 +168,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
       method: 'post',
       responseSchema: RunActionResponseSchema,
       data: sanitizedBody,
+      signal,
       // give up to 2 minutes for response
       timeout: 120000,
       ...axiosOptions,
@@ -313,7 +314,8 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * @returns an object with the response string and the usage object
    */
   public async invokeAI(body: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
-    const res = await this.runApi({ body: JSON.stringify(body) });
+    const { signal, ...rest } = body;
+    const res = await this.runApi({ body: JSON.stringify(rest), signal });
 
     if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
       const result = res.choices[0].message.content.trim();