[GenAI Connectors] Fix AbortSignal implementation (#180855)

## Bugs Fixed

1. The OpenAI `invokeAI` method did not properly handle `signal`
2. Bedrock did not have a `signal` implementation at all 😳

## Summary

In my [LangChain streaming
PR](https://github.com/elastic/kibana/pull/174126), I poorly implemented
a fix to stop the stream on the server when "Stop generating..." is hit
on the client. I did this by piping an `AbortSignal` through to the
`invokeStream`/`invokeAsyncIterator` subactions. However, in the
`invokeAI` subaction I did not remove `signal` before
`JSON.stringify`ing the body, so the error below occurred in the
Security non-streaming implementation. Additionally, for Bedrock I
somehow implemented `signal` only in part of the type and nowhere else,
so token tracking would be off when the "Stop generating..." button was hit 🤦

*(screenshot of the error)*
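To make the `invokeAI` bug concrete, here is a minimal sketch (illustrative names, not the actual connector code). An `AbortSignal` has no enumerable properties, so it serializes to `{}`, and leaving it on the params object means the stringified request body carries an unexpected `signal` field:

```ts
const controller = new AbortController();
const params = {
  messages: [{ role: 'user', content: 'Hello' }],
  signal: controller.signal,
};

// Before the fix: `signal` leaks into the serialized body sent to the LLM API.
JSON.stringify(params); // '{"messages":[...],"signal":{}}'

// After the fix: strip `signal` before serializing and hand it to axios separately,
// so it cancels the HTTP request instead of polluting the body.
const { signal, ...rest } = params;
JSON.stringify(rest); // '{"messages":[...]}'
```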


## To test

1. Turn off streaming in the Security AI Assistant and select an OpenAI
connector (LangChain off)
2. Send a message
3. Ensure the expected results (previously, the error above would occur)

The Bedrock fix is harder to confirm, because the issue shows up
subtly, in the token counter. Before I implemented the signal in the
Bedrock connector, if you asked Bedrock to repeat a word 100 times with
streaming enabled and then hit "Stop generating..." after 10 words, you
would see a `completion_tokens` count of ~100 tokens, as the full
response would have "streamed" on the server regardless. After this bug
fix, hitting "Stop generating..." after 10 words yields a
`completion_tokens` count of ~15 tokens, since it takes a second for
the `abort()` to reach the server. To be clear, this bug would not have
shown up in persistent storage, because we call abort in
`handleStreamStorage` ASAP instead of relying on axios to complete its
abort.
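
For reference, the end-to-end flow being tested looks roughly like this, a sketch under assumed names (`stopGenerating`, `callBedrock`, and the URL are hypothetical), not the actual assistant code:

```ts
import axios from 'axios';

const controller = new AbortController();

// Client: the "Stop generating..." button aborts the in-flight request.
function stopGenerating() {
  controller.abort();
}

// Server: the connector forwards the signal to axios (supported since axios 0.22),
// which cancels the HTTP request to Bedrock, so generation (and token counting)
// stops shortly after the abort arrives.
async function callBedrock(body: string, signal: AbortSignal) {
  return axios.post('https://bedrock.example/model/invoke', body, { signal });
}
```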
## Changed files

8 changed files with 118 additions and 12 deletions:

**Params schema snapshot:** the optional `signal` now appears in the serialized schema (two places in the same snapshot file).

```diff
@@ -45,6 +45,19 @@ Object {
       ],
       "type": "string",
     },
+    "signal": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "any",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
@@ -134,6 +147,19 @@ Object {
       ],
       "type": "string",
     },
+    "signal": Object {
+      "flags": Object {
+        "default": [Function],
+        "error": [Function],
+        "presence": "optional",
+      },
+      "metas": Array [
+        Object {
+          "x-oas-optional": true,
+        },
+      ],
+      "type": "any",
+    },
   },
   "preferences": Object {
     "stripUnknown": Object {
```

**Bedrock schema:** `signal` is added to `RunActionParamsSchema`, and the now-redundant `StreamActionParamsSchema` is removed.

```diff
@@ -22,6 +22,8 @@ export const SecretsSchema = schema.object({
 export const RunActionParamsSchema = schema.object({
   body: schema.string(),
   model: schema.maybe(schema.string()),
+  // abort signal from client
+  signal: schema.maybe(schema.any()),
 });

 export const InvokeAIActionParamsSchema = schema.object({
@@ -43,11 +45,6 @@ export const InvokeAIActionResponseSchema = schema.object({
   message: schema.string(),
 });

-export const StreamActionParamsSchema = schema.object({
-  body: schema.string(),
-  model: schema.maybe(schema.string()),
-});
-
 export const RunApiLatestResponseSchema = schema.object(
   {
     stop_reason: schema.maybe(schema.string()),
```
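Because an `AbortSignal` is not a serializable value, the schema accepts it as an optional `any`. A quick sketch of how that validates with `@kbn/config-schema` (standalone, outside the connector):

```ts
import { schema } from '@kbn/config-schema';

const RunActionParamsSchema = schema.object({
  body: schema.string(),
  // abort signal from client, treated as an opaque optional value
  signal: schema.maybe(schema.any()),
});

// Both shapes pass validation:
RunActionParamsSchema.validate({ body: '{"prompt":"hi"}' });
RunActionParamsSchema.validate({
  body: '{"prompt":"hi"}',
  signal: new AbortController().signal,
});
```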

**Bedrock types:** the `StreamActionParams` type goes away along with its schema.

```diff
@@ -15,7 +15,6 @@ import {
   RunActionResponseSchema,
   InvokeAIActionParamsSchema,
   InvokeAIActionResponseSchema,
-  StreamActionParamsSchema,
   StreamingResponseSchema,
   RunApiLatestResponseSchema,
 } from './schema';
@@ -27,7 +26,6 @@ export type InvokeAIActionParams = TypeOf<typeof InvokeAIActionParamsSchema>;
 export type InvokeAIActionResponse = TypeOf<typeof InvokeAIActionResponseSchema>;
 export type RunApiLatestResponse = TypeOf<typeof RunApiLatestResponseSchema>;
 export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
-export type StreamActionParams = TypeOf<typeof StreamActionParamsSchema>;
 export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
 export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
 export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
```
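Since `RunActionParamsSchema` and `StreamActionParamsSchema` carried the same shape, the latter is dropped and `streamApi` now shares `RunActionParams`, which is where the new `signal` lives.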

**OpenAI schema:** `signal` is added to `RunActionParamsSchema`.

```diff
@@ -28,6 +28,8 @@ export const SecretsSchema = schema.object({ apiKey: schema.string() });
 // Run action schema
 export const RunActionParamsSchema = schema.object({
   body: schema.string(),
+  // abort signal from client
+  signal: schema.maybe(schema.any()),
 });

 const AIMessage = schema.object({
```

**Bedrock connector tests:** new assertions that `signal` reaches the underlying request for both the streaming and non-streaming paths.

```diff
@@ -200,6 +200,21 @@ describe('BedrockConnector', () => {
       });
     });

+    it('signal is properly passed to streamApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeStream({ ...aiAssistantBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        signed: true,
+        url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke-with-response-stream`,
+        method: 'post',
+        responseSchema: StreamingResponseSchema,
+        responseType: 'stream',
+        data: JSON.stringify({ ...JSON.parse(DEFAULT_BODY), temperature: 0 }),
+        signal,
+      });
+    });
+
     it('ensureMessageFormat - formats messages from user, assistant, and system', async () => {
       await connector.invokeStream({
         messages: [
@@ -502,7 +517,25 @@ describe('BedrockConnector', () => {
       });
       expect(response.message).toEqual(mockResponseString);
     });

+    it('signal is properly passed to runApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeAI({ ...aiAssistantBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        signed: true,
+        timeout: 120000,
+        url: `${DEFAULT_BEDROCK_URL}/model/${DEFAULT_BEDROCK_MODEL}/invoke`,
+        method: 'post',
+        responseSchema: RunApiLatestResponseSchema,
+        data: JSON.stringify({
+          ...JSON.parse(DEFAULT_BODY),
+          messages: [{ content: 'Hello world', role: 'user' }],
+          max_tokens: DEFAULT_TOKEN_LIMIT,
+          temperature: 0,
+        }),
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
```
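Note that the tests pass `jest.fn()` where a real `AbortSignal` would go; the connector only forwards the value to the mocked request, so any stable reference is enough to assert the pass-through.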

**Bedrock connector:** `signal` is piped through `runApi`, `streamApi`, `invokeStream`, and `invokeAI`.

```diff
@@ -26,7 +26,6 @@ import {
   RunActionResponse,
   InvokeAIActionParams,
   InvokeAIActionResponse,
-  StreamActionParams,
   RunApiLatestResponse,
 } from '../../../common/bedrock/types';
 import { SUB_ACTION, DEFAULT_TOKEN_LIMIT } from '../../../common/bedrock/constants';
@@ -204,7 +203,11 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
    * @param body The stringified request body to be sent in the POST request.
    * @param model Optional model to be used for the API request. If not provided, the default model from the connector will be used.
    */
-  public async runApi({ body, model: reqModel }: RunActionParams): Promise<RunActionResponse> {
+  public async runApi({
+    body,
+    model: reqModel,
+    signal,
+  }: RunActionParams): Promise<RunActionResponse> {
     // set model on per request basis
     const currentModel = reqModel ?? this.model;
     const path = `/model/${currentModel}/invoke`;
@@ -214,6 +217,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
       url: `${this.url}${path}`,
       method: 'post' as Method,
       data: body,
+      signal,
       // give up to 2 minutes for response
       timeout: 120000,
     };
@@ -235,7 +239,8 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
   private async streamApi({
     body,
     model: reqModel,
-  }: StreamActionParams): Promise<StreamingResponse> {
+    signal,
+  }: RunActionParams): Promise<StreamingResponse> {
     // set model on per request basis
     const path = `/model/${reqModel ?? this.model}/invoke-with-response-stream`;
     const signed = this.signRequest(body, path, true);
@@ -247,6 +252,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
       responseSchema: StreamingResponseSchema,
       data: body,
       responseType: 'stream',
+      signal,
     });

     return response.data.pipe(new PassThrough());
@@ -266,10 +272,12 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
     stopSequences,
     system,
     temperature,
+    signal,
   }: InvokeAIActionParams): Promise<IncomingMessage> {
     const res = (await this.streamApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
+      signal,
     })) as unknown as IncomingMessage;
     return res;
   }
@@ -288,10 +296,12 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
     stopSequences,
     system,
     temperature,
+    signal,
   }: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
     const res = await this.runApi({
       body: JSON.stringify(formatBedrockBody({ messages, stopSequences, system, temperature })),
       model,
+      signal,
     });
     return { message: res.completion.trim() };
   }
```
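With the plumbing in place, cancelling a Bedrock invocation looks roughly like this (a sketch, not actual executor code):

```ts
// `connector` stands in for an instantiated BedrockConnector (hypothetical setup).
const controller = new AbortController();

const pending = connector.invokeAI({
  messages: [{ role: 'user', content: 'Repeat the word "pony" 100 times' }],
  signal: controller.signal,
});

// "Stop generating..." on the client eventually triggers:
controller.abort();

// axios rejects the in-flight Bedrock request, so the model stops generating
// (and accruing completion tokens) shortly after the abort reaches the server.
```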

**OpenAI connector tests:** the same pass-through assertions for the OpenAI streaming and non-streaming paths.

```diff
@@ -336,6 +336,25 @@ describe('OpenAIConnector', () => {
       });
     });

+    it('signal is properly passed to streamApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeStream({ ...sampleOpenAiBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        url: 'https://api.openai.com/v1/chat/completions',
+        method: 'post',
+        responseSchema: StreamingResponseSchema,
+        responseType: 'stream',
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: true, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
@@ -371,6 +390,25 @@ describe('OpenAIConnector', () => {
       expect(response.usage.total_tokens).toEqual(9);
     });

+    it('signal is properly passed to runApi', async () => {
+      const signal = jest.fn();
+      await connector.invokeAI({ ...sampleOpenAiBody, signal });
+
+      expect(mockRequest).toHaveBeenCalledWith({
+        timeout: 120000,
+        url: 'https://api.openai.com/v1/chat/completions',
+        method: 'post',
+        responseSchema: RunActionResponseSchema,
+        data: JSON.stringify({ ...sampleOpenAiBody, stream: false, model: DEFAULT_OPENAI_MODEL }),
+        headers: {
+          Authorization: 'Bearer 123',
+          'X-My-Custom-Header': 'foo',
+          'content-type': 'application/json',
+        },
+        signal,
+      });
+    });
+
     it('errors during API calls are properly handled', async () => {
       // @ts-ignore
       connector.request = mockError;
```

**OpenAI connector:** `runApi` accepts `signal`, and `invokeAI` strips it from the body before `JSON.stringify`.

```diff
@@ -155,7 +155,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * responsible for making a POST request to the external API endpoint and returning the response data
    * @param body The stringified request body to be sent in the POST request.
    */
-  public async runApi({ body }: RunActionParams): Promise<RunActionResponse> {
+  public async runApi({ body, signal }: RunActionParams): Promise<RunActionResponse> {
     const sanitizedBody = sanitizeRequest(
       this.provider,
       this.url,
@@ -168,6 +168,7 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
       method: 'post',
       responseSchema: RunActionResponseSchema,
       data: sanitizedBody,
+      signal,
       // give up to 2 minutes for response
       timeout: 120000,
       ...axiosOptions,
@@ -313,7 +314,8 @@ export class OpenAIConnector extends SubActionConnector<Config, Secrets> {
    * @returns an object with the response string and the usage object
    */
   public async invokeAI(body: InvokeAIActionParams): Promise<InvokeAIActionResponse> {
-    const res = await this.runApi({ body: JSON.stringify(body) });
+    const { signal, ...rest } = body;
+    const res = await this.runApi({ body: JSON.stringify(rest), signal });
     if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
       const result = res.choices[0].message.content.trim();
```
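The `invokeAI` change above is the heart of the OpenAI fix: `signal` is destructured out of the body before `JSON.stringify`, then passed to `runApi` alongside the clean body so axios can use it for cancellation instead of serializing it.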