[AI Assistant] Fix OpenAI, error race condition bug (#205665)

Steph Milovic 2025-01-07 08:12:20 -07:00 committed by GitHub
parent bccf0c99c9
commit 2c70e8651e
7 changed files with 108 additions and 699 deletions
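
The race condition this commit fixes: `streamGraph` consumed the LangGraph event stream by recursively awaiting `stream.next()` inside a fire-and-forget `processEvent()`, and an `on_llm_end` event whose payload carried no generations (the error case) still called `handleStreamEnd`, so an errored OpenAI run could close the stream before the error path reported anything. The changes below replace the recursion with a `for await...of` loop over the event stream and skip `handleStreamEnd` when `generations` is undefined, leaving completion to the error handling. A minimal sketch of the new consumption pattern, assuming simplified `push` and `handleStreamEnd` callbacks and a placeholder `AGENT_NODE_TAG` value (names mirror the diff; the types are illustrative):

// Sketch only: simplified from the stream_graph changes in this commit.
// The event shape, tag value, and callbacks are illustrative stand-ins.
interface GraphStreamEvent {
  event: string;
  tags?: string[];
  data: {
    chunk?: { message: { content: string } };
    output?: {
      generations?: Array<
        Array<{ text: string; generationInfo?: { finish_reason?: string } }>
      >;
    };
  };
}

const AGENT_NODE_TAG = 'agent_run'; // placeholder; the real constant comes from the graph definition

async function consumeAgentEvents(
  stream: AsyncIterable<GraphStreamEvent>,
  push: (chunk: { payload: string; type: 'content' }) => void,
  handleStreamEnd: (finalMessage: string) => void
): Promise<void> {
  let finalMessage = '';
  for await (const { event, data, tags } of stream) {
    // Only process events that are part of the agent run.
    if (!(tags ?? []).includes(AGENT_NODE_TAG)) continue;
    if (event === 'on_llm_stream') {
      const content = data.chunk?.message.content ?? '';
      push({ payload: content, type: 'content' });
      finalMessage += content;
    } else if (event === 'on_llm_end') {
      const generation = data.output?.generations?.[0]?.[0];
      if (
        // If generations is missing, an error occurred upstream: do nothing here
        // and let the error handling complete the stream (the bug being fixed).
        generation != null &&
        // No finish_reason means the stream was aborted.
        (!generation.generationInfo?.finish_reason ||
          generation.generationInfo.finish_reason === 'stop')
      ) {
        handleStreamEnd(generation.text.length ? generation.text : finalMessage);
      }
    }
  }
}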

View file

@@ -1,149 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { executeAction, Props } from './executor';
import { PassThrough } from 'stream';
import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/actions_client.mock';
import { loggerMock } from '@kbn/logging-mocks';
import * as ParseStream from './parse_stream';
const onLlmResponse = jest.fn(async () => {}); // We need it to be a promise, or it'll crash because of missing `.catch`
const connectorId = 'testConnectorId';
const mockLogger = loggerMock.create();
const testProps: Omit<Props, 'actions'> = {
params: {
subAction: 'invokeAI',
subActionParams: { messages: [{ content: 'hello', role: 'user' }] },
},
actionTypeId: '.bedrock',
connectorId,
actionsClient: actionsClientMock.create(),
onLlmResponse,
logger: mockLogger,
};
const handleStreamStorageSpy = jest.spyOn(ParseStream, 'handleStreamStorage');
describe('executeAction', () => {
beforeEach(() => {
jest.clearAllMocks();
});
it('should execute an action and return a StaticResponse when the response from the actions framework is a string', async () => {
testProps.actionsClient.execute = jest.fn().mockResolvedValue({
data: {
message: 'Test message',
},
});
const result = await executeAction({ ...testProps });
expect(result).toEqual({
connector_id: connectorId,
data: 'Test message',
status: 'ok',
});
expect(onLlmResponse).toHaveBeenCalledWith('Test message');
});
it('should execute an action and return a Readable object when the response from the actions framework is a stream', async () => {
const readableStream = new PassThrough();
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementationOnce(
jest.fn().mockResolvedValue({
status: 'ok',
data: readableStream,
})
);
const result = await executeAction({ ...testProps, actionsClient });
expect(JSON.stringify(result)).toStrictEqual(
JSON.stringify(readableStream.pipe(new PassThrough()))
);
expect(handleStreamStorageSpy).toHaveBeenCalledWith({
actionTypeId: '.bedrock',
onMessageSent: onLlmResponse,
logger: mockLogger,
responseStream: readableStream,
});
});
it('should throw an error if the actions client fails to execute the action', async () => {
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockRejectedValue(new Error('Failed to execute action'));
testProps.actionsClient = actionsClient;
await expect(executeAction({ ...testProps, actionsClient })).rejects.toThrowError(
'Failed to execute action'
);
});
it('should throw an error when the response from the actions framework is null or undefined', async () => {
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementationOnce(
jest.fn().mockResolvedValue({
data: null,
})
);
testProps.actionsClient = actionsClient;
try {
await executeAction({ ...testProps, actionsClient });
} catch (e) {
expect(e.message).toBe('Action result status is error: result is not streamable');
}
});
it('should throw an error if action result status is "error"', async () => {
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementationOnce(
jest.fn().mockResolvedValue({
status: 'error',
message: 'Error message',
serviceMessage: 'Service error message',
})
);
testProps.actionsClient = actionsClient;
await expect(
executeAction({
...testProps,
actionsClient,
connectorId: '12345',
})
).rejects.toThrowError('Action result status is error: Error message - Service error message');
});
it('should throw an error if content of response data is not a string or streamable', async () => {
const actionsClient = actionsClientMock.create();
actionsClient.execute.mockImplementationOnce(
jest.fn().mockResolvedValue({
status: 'ok',
data: {
message: 12345,
},
})
);
testProps.actionsClient = actionsClient;
await expect(
executeAction({
...testProps,
actionsClient,
connectorId: '12345',
})
).rejects.toThrowError('Action result status is error: result is not streamable');
});
});

View file

@@ -1,94 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { get } from 'lodash/fp';
import { ActionsClient } from '@kbn/actions-plugin/server';
import { PassThrough, Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { PublicMethodsOf } from '@kbn/utility-types';
import { handleStreamStorage } from './parse_stream';
export interface Props {
onLlmResponse?: (content: string) => Promise<void>;
abortSignal?: AbortSignal;
actionsClient: PublicMethodsOf<ActionsClient>;
connectorId: string;
params: InvokeAIActionsParams;
actionTypeId: string;
logger: Logger;
}
export interface StaticResponse {
connector_id: string;
data: string;
status: string;
}
interface InvokeAIActionsParams {
subActionParams: {
messages: Array<{ role: string; content: string }>;
model?: string;
n?: number;
stop?: string | string[] | null;
stopSequences?: string[];
temperature?: number;
};
subAction: 'invokeAI' | 'invokeStream';
}
export const executeAction = async ({
onLlmResponse,
actionsClient,
params,
connectorId,
actionTypeId,
logger,
abortSignal,
}: Props): Promise<StaticResponse | Readable> => {
const actionResult = await actionsClient.execute({
actionId: connectorId,
params: {
subAction: params.subAction,
subActionParams: {
...params.subActionParams,
signal: abortSignal,
},
},
});
if (actionResult.status === 'error') {
throw new Error(
`Action result status is error: ${actionResult?.message} - ${actionResult?.serviceMessage}`
);
}
const content = get('data.message', actionResult);
if (typeof content === 'string') {
if (onLlmResponse) {
await onLlmResponse(content);
}
return {
connector_id: connectorId,
data: content, // the response from the actions framework
status: 'ok',
};
}
const readable = get('data', actionResult) as Readable;
if (typeof readable?.read !== 'function') {
throw new Error('Action result status is error: result is not streamable');
}
// do not await, blocks stream for UI
handleStreamStorage({
actionTypeId,
onMessageSent: onLlmResponse,
logger,
responseStream: readable,
abortSignal,
}).catch(() => {});
return readable.pipe(new PassThrough());
};

View file

@@ -74,35 +74,24 @@ describe('streamGraph', () => {
describe('OpenAI Function Agent streaming', () => {
it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
],
},
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
],
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
tags: [AGENT_NODE_TAG],
};
},
});
const response = await streamGraph(requestArgs);
@@ -119,33 +108,22 @@
});
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
},
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
tags: [AGENT_NODE_TAG],
};
},
});
const response = await streamGraph(requestArgs);
@@ -158,33 +136,22 @@
});
it('on_llm_end events without a finish_reason should end the stream', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
},
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
tags: [AGENT_NODE_TAG],
};
},
});
const response = await streamGraph(requestArgs);
@@ -201,33 +168,22 @@
});
it('on_llm_end events is called with chunks if there is no final text value', async () => {
mockStreamEvents.mockReturnValue({
next: jest
.fn()
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
},
done: false,
})
.mockResolvedValueOnce({
value: {
name: 'ActionsClientChatOpenAI',
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
},
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
},
tags: [AGENT_NODE_TAG],
},
})
.mockResolvedValue({
done: true,
}),
return: jest.fn(),
tags: [AGENT_NODE_TAG],
};
},
});
const response = await streamGraph(requestArgs);
@@ -242,6 +198,28 @@
);
});
});
it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {},
tags: [AGENT_NODE_TAG],
};
},
});
const response = await streamGraph(requestArgs);
expect(response).toBe(mockResponseWithHeaders);
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).not.toHaveBeenCalled();
});
});
describe('Tool Calling Agent and Structured Chat Agent streaming', () => {
@@ -330,7 +308,7 @@
await expectConditions(response);
});
it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
it('should execute the graph in streaming mode - OpenAI + isOssModel = true', async () => {
const mockAssistantGraphAsyncIterator = {
streamEvents: () => mockAsyncIterator,
} as unknown as DefaultAssistantGraph;
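
Because the loop now consumes an async iterable directly, the OpenAI streaming tests above mock `streamEvents` with a plain object implementing `Symbol.asyncIterator` rather than stubbing successive `next()` calls. A minimal Jest sketch of the new mock shape, abbreviated from the cases above:

// Abbreviated sketch of the async-iterable mock used in the tests above.
mockStreamEvents.mockReturnValue({
  async *[Symbol.asyncIterator]() {
    yield {
      event: 'on_llm_stream',
      data: { chunk: { message: { content: 'content' } } },
      tags: [AGENT_NODE_TAG],
    };
    yield {
      event: 'on_llm_end',
      data: {
        output: {
          generations: [[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }]],
        },
      },
      tags: [AGENT_NODE_TAG],
    };
  },
});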

View file

@@ -9,7 +9,6 @@ import agent, { Span } from 'elastic-apm-node';
import type { Logger } from '@kbn/logging';
import { TelemetryTracer } from '@kbn/langchain/server/tracers/telemetry';
import { streamFactory, StreamResponseWithHeaders } from '@kbn/ml-response-stream/server';
import { transformError } from '@kbn/securitysolution-es-utils';
import type { KibanaRequest } from '@kbn/core-http-server';
import type { ExecuteConnectorRequestBody, TraceData } from '@kbn/elastic-assistant-common';
import { APMTracer } from '@kbn/langchain/server/tracers/apm';
@@ -126,7 +125,6 @@ export const streamGraph = async ({
// Stream is from openai functions agent
let finalMessage = '';
let conversationId: string | undefined;
const stream = assistantGraph.streamEvents(inputs, {
callbacks: [
apmTracer,
@@ -139,63 +137,37 @@
version: 'v1',
});
const processEvent = async () => {
try {
const { value, done } = await stream.next();
if (done) return;
const event = value;
// only process events that are part of the agent run
if ((event.tags || []).includes(AGENT_NODE_TAG)) {
if (event.name === 'ActionsClientChatOpenAI') {
if (event.event === 'on_llm_stream') {
const chunk = event.data?.chunk;
const msg = chunk.message;
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
// however, no harm to keep it in
/* empty */
} else if (!didEnd) {
push({ payload: msg.content, type: 'content' });
finalMessage += msg.content;
}
} else if (event.event === 'on_llm_end' && !didEnd) {
const generation = event.data.output?.generations[0][0];
if (
// no finish_reason means the stream was aborted
!generation?.generationInfo?.finish_reason ||
generation?.generationInfo?.finish_reason === 'stop'
) {
handleStreamEnd(
generation?.text && generation?.text.length ? generation?.text : finalMessage
);
}
}
for await (const { event, data, tags } of stream) {
if ((tags || []).includes(AGENT_NODE_TAG)) {
if (event === 'on_llm_stream') {
const chunk = data?.chunk;
const msg = chunk.message;
if (msg?.tool_call_chunks && msg?.tool_call_chunks.length > 0) {
// I don't think we hit this anymore because of our check for AGENT_NODE_TAG
// however, no harm to keep it in
/* empty */
} else if (!didEnd) {
push({ payload: msg.content, type: 'content' });
finalMessage += msg.content;
}
}
void processEvent();
} catch (err) {
// if I throw an error here, it crashes the server. Not sure how to get around that.
// If I put await on this function the error works properly, but when there is not an error
// it waits for the entire stream to complete before resolving
const error = transformError(err);
if (error.message === 'AbortError') {
// user aborted the stream, we must end it manually here
return handleStreamEnd(finalMessage);
if (event === 'on_llm_end' && !didEnd) {
const generation = data.output?.generations[0][0];
if (
// if generation is null, an error occurred - do nothing and let error handling complete the stream
generation != null &&
// no finish_reason means the stream was aborted
(!generation?.generationInfo?.finish_reason ||
generation?.generationInfo?.finish_reason === 'stop')
) {
handleStreamEnd(
generation?.text && generation?.text.length ? generation?.text : finalMessage
);
}
}
logger.error(`Error streaming from LangChain: ${error.message}`);
if (conversationId) {
push({ payload: `Conversation id: ${conversationId}`, type: 'content' });
}
push({ payload: error.message, type: 'content' });
handleStreamEnd(error.message, true);
}
};
// Start processing events, do not await! Return `responseWithHeaders` immediately
void processEvent();
}
return responseWithHeaders;
};

View file

@@ -1,168 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Readable, Transform } from 'stream';
import { loggerMock } from '@kbn/logging-mocks';
import { handleStreamStorage } from './parse_stream';
import { EventStreamCodec } from '@smithy/eventstream-codec';
import { fromUtf8, toUtf8 } from '@smithy/util-utf8';
function createStreamMock() {
const transform: Transform = new Transform({});
return {
write: (data: unknown) => {
transform.push(data);
},
fail: () => {
transform.emit('error', new Error('Stream failed'));
transform.end();
},
transform,
complete: () => {
transform.end();
},
};
}
const mockLogger = loggerMock.create();
const onMessageSent = jest.fn();
describe('handleStreamStorage', () => {
beforeEach(() => {
jest.resetAllMocks();
});
let stream: ReturnType<typeof createStreamMock>;
const chunk = {
object: 'chat.completion.chunk',
choices: [
{
delta: {
content: 'Single.',
},
},
],
};
let defaultProps = {
responseStream: jest.fn() as unknown as Readable,
actionTypeId: '.gen-ai',
onMessageSent,
logger: mockLogger,
};
describe('OpenAI stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(`data: ${JSON.stringify(chunk)}`);
defaultProps = {
responseStream: stream.transform,
actionTypeId: '.gen-ai',
onMessageSent,
logger: mockLogger,
};
});
it('saves the final string successful streaming event', async () => {
stream.complete();
await handleStreamStorage(defaultProps);
expect(onMessageSent).toHaveBeenCalledWith('Single.');
});
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage(defaultProps);
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();
expect(onMessageSent).toHaveBeenCalledWith(
`An error occurred while streaming the response:\n\nStream failed`
);
});
});
describe('Bedrock stream', () => {
beforeEach(() => {
stream = createStreamMock();
stream.write(encodeBedrockResponse('Simple.'));
defaultProps = {
responseStream: stream.transform,
actionTypeId: '.gen-ai',
onMessageSent,
logger: mockLogger,
};
});
it('saves the final string successful streaming event', async () => {
stream.complete();
await handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
expect(onMessageSent).toHaveBeenCalledWith('Simple.');
});
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage({ ...defaultProps, actionTypeId: '.bedrock' });
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();
expect(onMessageSent).toHaveBeenCalledWith(
`An error occurred while streaming the response:\n\nStream failed`
);
});
});
describe('Gemini stream', () => {
beforeEach(() => {
stream = createStreamMock();
const payload = {
candidates: [
{
content: {
parts: [
{
text: 'Single.',
},
],
},
},
],
};
stream.write(`data: ${JSON.stringify(payload)}`);
defaultProps = {
responseStream: stream.transform,
actionTypeId: '.gemini',
onMessageSent,
logger: mockLogger,
};
});
it('saves the final string successful streaming event', async () => {
stream.complete();
await handleStreamStorage(defaultProps);
expect(onMessageSent).toHaveBeenCalledWith('Single.');
});
it('saves the error message on a failed streaming event', async () => {
const tokenPromise = handleStreamStorage(defaultProps);
stream.fail();
await expect(tokenPromise).resolves.not.toThrow();
expect(onMessageSent).toHaveBeenCalledWith(
`An error occurred while streaming the response:\n\nStream failed`
);
});
});
});
function encodeBedrockResponse(completion: string) {
return new EventStreamCodec(toUtf8, fromUtf8).encode({
headers: {},
body: Uint8Array.from(
Buffer.from(
JSON.stringify({
bytes: Buffer.from(
JSON.stringify({
type: 'content_block_delta',
index: 0,
delta: { type: 'text_delta', text: completion },
})
).toString('base64'),
})
)
),
});
}

View file

@@ -1,118 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server';
type StreamParser = (
responseStream: Readable,
logger: Logger,
abortSignal?: AbortSignal,
tokenHandler?: (token: string) => void
) => Promise<string>;
export const handleStreamStorage = async ({
abortSignal,
responseStream,
actionTypeId,
onMessageSent,
logger,
}: {
abortSignal?: AbortSignal;
responseStream: Readable;
actionTypeId: string;
onMessageSent?: (content: string) => void;
logger: Logger;
}): Promise<void> => {
try {
const parser =
actionTypeId === '.bedrock'
? parseBedrockStream
: actionTypeId === '.gemini'
? parseGeminiStream
: parseOpenAIStream;
const parsedResponse = await parser(responseStream, logger, abortSignal);
if (onMessageSent) {
onMessageSent(parsedResponse);
}
} catch (e) {
if (onMessageSent) {
onMessageSent(`An error occurred while streaming the response:\n\n${e.message}`);
}
}
};
const parseOpenAIStream: StreamParser = async (stream, logger, abortSignal) => {
let responseBody = '';
stream.on('data', (chunk) => {
responseBody += chunk.toString();
});
return new Promise((resolve, reject) => {
stream.on('end', () => {
resolve(parseOpenAIResponse(responseBody));
});
stream.on('error', (err) => {
reject(err);
});
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
stream.destroy();
resolve(parseOpenAIResponse(responseBody));
});
}
});
};
const parseOpenAIResponse = (responseBody: string) =>
responseBody
.split('\n')
.filter((line) => {
return line.startsWith('data: ') && !line.endsWith('[DONE]');
})
.map((line) => {
return JSON.parse(line.replace('data: ', ''));
})
.filter(
(
line
): line is {
choices: Array<{
delta: { content?: string; function_call?: { name?: string; arguments: string } };
}>;
} => {
return (
'object' in line && line.object === 'chat.completion.chunk' && line.choices.length > 0
);
}
)
.reduce((prev, line) => {
const msg = line.choices[0].delta;
return prev + (msg.content || '');
}, '');
export const parseGeminiStream: StreamParser = async (stream, logger, abortSignal) => {
let responseBody = '';
stream.on('data', (chunk) => {
responseBody += chunk.toString();
});
return new Promise((resolve, reject) => {
stream.on('end', () => {
resolve(parseGeminiResponse(responseBody));
});
stream.on('error', (err) => {
reject(err);
});
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
stream.destroy();
logger.info('Gemini stream parsing was aborted.');
resolve(parseGeminiResponse(responseBody));
});
}
});
};

View file

@@ -30,19 +30,7 @@ const actionsClient = actionsClientMock.create();
jest.mock('../lib/build_response', () => ({
buildResponse: jest.fn().mockImplementation((x) => x),
}));
jest.mock('../lib/executor', () => ({
executeAction: jest.fn().mockImplementation(async ({ connectorId }) => {
if (connectorId === 'mock-connector-id') {
return {
connector_id: 'mock-connector-id',
data: mockActionResponse,
status: 'ok',
};
} else {
throw new Error('simulated error');
}
}),
}));
const mockStream = jest.fn().mockImplementation(() => new PassThrough());
const mockLangChainExecute = langChainExecute as jest.Mock;
const mockAppendAssistantMessageToConversation = appendAssistantMessageToConversation as jest.Mock;